From: Mike Bayer Date: Mon, 27 Apr 2020 16:58:12 +0000 (-0400) Subject: Convert execution to move through Session X-Git-Tag: rel_1_4_0b1~302 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6930dfc032c3f9f474e71ab4e021c0ef8384930e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Convert execution to move through Session This patch replaces the ORM execution flow with a single pathway through Session.execute() for all queries, including Core and ORM. Currently included is full support for ORM Query, Query.from_statement(), select(), as well as the baked query and horizontal shard systems. Initial changes have also been made to the dogpile caching example, which like baked query makes use of a new ORM-specific execution hook that replaces the use of both QueryEvents.before_compile() as well as Query._execute_and_instances() as the central ORM interception hooks. select() and Query() constructs alike can be passed to Session.execute() where they will return ORM results in a Results object. This API is currently used internally by Query. Full support for Session.execute()->results to behave in a fully 2.0 fashion will be in later changesets. bulk update/delete with ORM support will also be delivered via the update() and delete() constructs, however these have not yet been adapted to the new system and may follow in a subsequent update. Performance is also beginning to lag as of this commit and some previous ones. It is hoped that a few central functions such as the coercions functions can be rewritten in C to re-gain performance. Additionally, query caching is now available and some subsequent patches will attempt to cache more of the per-execution work from the ORM layer, e.g. column getters and adapters. This patch also contains initial "turn on" of the caching system enginewide via the query_cache_size parameter to create_engine(). Still defaulting at zero for "no caching". The caching system still needs adjustments in order to gain adequate performance. Change-Id: I047a7ebb26aa85dc01f6789fac2bff561dcd555d --- diff --git a/doc/build/orm/extensions/baked.rst b/doc/build/orm/extensions/baked.rst index 8614cd048a..951f35e6ae 100644 --- a/doc/build/orm/extensions/baked.rst +++ b/doc/build/orm/extensions/baked.rst @@ -20,6 +20,11 @@ cache the **return results** from the database. A technique that demonstrates the caching of the SQL calls and result sets themselves is available in :ref:`examples_caching`. +.. deprecated:: 1.4 SQLAlchemy 1.4 and 2.0 feature an all-new direct query + caching system that removes the need for the :class:`.BakedQuery` system. + Caching is now built in to all Core and ORM queries using the + :paramref:`.create_engine.query_cache_size` parameter. + .. versionadded:: 1.0.0 diff --git a/doc/build/orm/session_api.rst b/doc/build/orm/session_api.rst index e247a8de78..849472e9f0 100644 --- a/doc/build/orm/session_api.rst +++ b/doc/build/orm/session_api.rst @@ -1,4 +1,4 @@ -.. currentmodule:: sqlalchemy.orm.session +.. currentmodule:: sqlalchemy.orm Session API =========== @@ -10,11 +10,45 @@ Session and sessionmaker() :members: :inherited-members: -.. autoclass:: sqlalchemy.orm.session.Session +.. autoclass:: ORMExecuteState + :members: + + + .. attribute:: session + + The :class:`_orm.Session` in use. + + .. attribute:: statement + + The SQL statement being invoked. For an ORM selection as would + be retrieved from :class:`_orm.Query`, this is an instance of + :class:`_future.select` that was generated from the ORM query. + + .. attribute:: parameters + + Dictionary of parameters that was passed to :meth:`_orm.Session.execute`. + + .. attribute:: execution_options + + Dictionary of execution options passed to :meth:`_orm.Session.execute`. + Note that this dictionary does not include execution options that may + be associated with the statement itself, or with any underlying + :class:`_engine.Connection` that may be used to invoke this statement. + + .. attribute:: bind_arguments + + The dictionary passed as the + :paramref:`_orm.Session.execute.bind_arguments` dictionary. This + dictionary may be used by extensions to :class:`_orm.Session` to pass + arguments that will assist in determining amongst a set of database + connections which one should be used to invoke this statement. + + +.. autoclass:: Session :members: :inherited-members: -.. autoclass:: sqlalchemy.orm.session.SessionTransaction +.. autoclass:: SessionTransaction :members: Session Utilities diff --git a/examples/dogpile_caching/advanced.py b/examples/dogpile_caching/advanced.py index d2ef825562..e72921ba4f 100644 --- a/examples/dogpile_caching/advanced.py +++ b/examples/dogpile_caching/advanced.py @@ -5,6 +5,7 @@ including front-end loading, cache invalidation and collection caching. from .caching_query import FromCache from .caching_query import RelationshipCache +from .environment import cache from .environment import Session from .model import cache_address_bits from .model import Person @@ -48,7 +49,8 @@ def load_name_range(start, end, invalidate=False): # if requested, invalidate the cache on current criterion. if invalidate: - q.invalidate() + cache.invalidate(q, {}, FromCache("default", "name_range")) + cache.invalidate(q, {}, RelationshipCache(Person.addresses, "default")) return q.all() diff --git a/examples/dogpile_caching/caching_query.py b/examples/dogpile_caching/caching_query.py index d6e1435b0a..54f712a11b 100644 --- a/examples/dogpile_caching/caching_query.py +++ b/examples/dogpile_caching/caching_query.py @@ -2,16 +2,18 @@ which allow the usage of Dogpile caching with SQLAlchemy. Introduces a query option called FromCache. +.. versionchanged:: 1.4 the caching approach has been altered to work + based on a session event. + + The three new concepts introduced here are: - * CachingQuery - a Query subclass that caches and + * ORMCache - an extension for an ORM :class:`.Session` retrieves results in/from dogpile.cache. * FromCache - a query option that establishes caching parameters on a Query * RelationshipCache - a variant of FromCache which is specific to a query invoked during a lazy load. - * _params_from_query - extracts value parameters from - a Query. The rest of what's here are standard SQLAlchemy and dogpile.cache constructs. @@ -19,165 +21,97 @@ dogpile.cache constructs. """ from dogpile.cache.api import NO_VALUE -from sqlalchemy.orm.interfaces import MapperOption -from sqlalchemy.orm.query import Query +from sqlalchemy import event +from sqlalchemy.orm import loading +from sqlalchemy.orm.interfaces import UserDefinedOption -class CachingQuery(Query): - """A Query subclass which optionally loads full results from a dogpile - cache region. +class ORMCache(object): - The CachingQuery optionally stores additional state that allows it to - consult a dogpile.cache cache before accessing the database, in the form of - a FromCache or RelationshipCache object. Each of these objects refer to - the name of a :class:`dogpile.cache.Region` that's been configured and - stored in a lookup dictionary. When such an object has associated itself - with the CachingQuery, the corresponding :class:`dogpile.cache.Region` is - used to locate a cached result. If none is present, then the Query is - invoked normally, the results being cached. + """An add-on for an ORM :class:`.Session` optionally loads full results + from a dogpile cache region. - The FromCache and RelationshipCache mapper options below represent - the "public" method of configuring this state upon the CachingQuery. """ - def __init__(self, regions, *args, **kw): + def __init__(self, regions): self.cache_regions = regions - Query.__init__(self, *args, **kw) - - # NOTE: as of 1.4 don't override __iter__() anymore, the result object - # cannot be cached at that level. + self._statement_cache = {} + + def listen_on_session(self, session_factory): + event.listen(session_factory, "do_orm_execute", self._do_orm_execute) + + def _do_orm_execute(self, orm_context): + + for opt in orm_context.user_defined_options: + if isinstance(opt, RelationshipCache): + opt = opt._process_orm_context(orm_context) + if opt is None: + continue + + if isinstance(opt, FromCache): + dogpile_region = self.cache_regions[opt.region] + + our_cache_key = opt._generate_cache_key( + orm_context.statement, orm_context.parameters, self + ) + + if opt.ignore_expiration: + cached_value = dogpile_region.get( + our_cache_key, + expiration_time=opt.expiration_time, + ignore_expiration=opt.ignore_expiration, + ) + else: + + def createfunc(): + return orm_context.invoke_statement().freeze() + + cached_value = dogpile_region.get_or_create( + our_cache_key, + createfunc, + expiration_time=opt.expiration_time, + ) + + if cached_value is NO_VALUE: + # keyerror? this is bigger than a keyerror... + raise KeyError() + + orm_result = loading.merge_frozen_result( + orm_context.session, + orm_context.statement, + cached_value, + load=False, + ) + return orm_result() - def _execute_and_instances(self, context, **kw): - """override _execute_and_instances to pull results from dogpile - if the query is invoked directly from an external context. + else: + return None - This method is necessary in order to maintain compatibility - with the "baked query" system now used by default in some - relationship loader scenarios. Note also the - RelationshipCache._generate_cache_key method which enables - the baked query to be used within lazy loads. + def invalidate(self, statement, parameters, opt): + """Invalidate the cache value represented by a statement.""" - .. versionadded:: 1.2.7 + statement = statement.__clause_element__() - .. versionchanged:: 1.4 Added ``**kw`` arguments to the signature. + dogpile_region = self.cache_regions[opt.region] - """ - super_ = super(CachingQuery, self) - - if hasattr(self, "_cache_region"): - # special logic called when the Query._execute_and_instances() - # method is called directly from the baked query - return self.get_value( - createfunc=lambda: super_._execute_and_instances( - context, **kw - ).freeze() - ) - else: - return super_._execute_and_instances(context, **kw) + cache_key = opt._generate_cache_key(statement, parameters, self) - def _get_cache_plus_key(self): - """Return a cache region plus key.""" + dogpile_region.delete(cache_key) - dogpile_region = self.cache_regions[self._cache_region.region] - if self._cache_region.cache_key: - key = self._cache_region.cache_key - else: - key = _key_from_query(self) - return dogpile_region, key - def invalidate(self): - """Invalidate the cache value represented by this Query.""" +class FromCache(UserDefinedOption): + """Specifies that a Query should load results from a cache.""" - dogpile_region, cache_key = self._get_cache_plus_key() - dogpile_region.delete(cache_key) + propagate_to_loaders = False - def get_value( + def __init__( self, - merge=True, - createfunc=None, + region="default", + cache_key=None, expiration_time=None, ignore_expiration=False, ): - """Return the value from the cache for this query. - - Raise KeyError if no value present and no - createfunc specified. - - """ - dogpile_region, cache_key = self._get_cache_plus_key() - - # ignore_expiration means, if the value is in the cache - # but is expired, return it anyway. This doesn't make sense - # with createfunc, which says, if the value is expired, generate - # a new value. - assert ( - not ignore_expiration or not createfunc - ), "Can't ignore expiration and also provide createfunc" - - if ignore_expiration or not createfunc: - cached_value = dogpile_region.get( - cache_key, - expiration_time=expiration_time, - ignore_expiration=ignore_expiration, - ) - else: - cached_value = dogpile_region.get_or_create( - cache_key, createfunc, expiration_time=expiration_time - ) - if cached_value is NO_VALUE: - raise KeyError(cache_key) - - # in 1.4 the cached value is a FrozenResult. merge_result - # accommodates this directly and updates the ORM entities inside - # the object to be merged. - # TODO: should this broken into merge_frozen_result / merge_iterator? - if merge: - cached_value = self.merge_result(cached_value, load=False) - return cached_value() - - def set_value(self, value): - """Set the value in the cache for this query.""" - - dogpile_region, cache_key = self._get_cache_plus_key() - dogpile_region.set(cache_key, value) - - -def query_callable(regions, query_cls=CachingQuery): - def query(*arg, **kw): - return query_cls(regions, *arg, **kw) - - return query - - -def _key_from_query(query, qualifier=None): - """Given a Query, create a cache key. - - There are many approaches to this; here we use the simplest, - which is to create an md5 hash of the text of the SQL statement, - combined with stringified versions of all the bound parameters - within it. There's a bit of a performance hit with - compiling out "query.statement" here; other approaches include - setting up an explicit cache key with a particular Query, - then combining that with the bound parameter values. - - """ - - stmt = query.with_labels().statement - compiled = stmt.compile() - params = compiled.params - - # here we return the key as a long string. our "key mangler" - # set up with the region will boil it down to an md5. - return " ".join([str(compiled)] + [str(params[k]) for k in sorted(params)]) - - -class FromCache(MapperOption): - """Specifies that a Query should load results from a cache.""" - - propagate_to_loaders = False - - def __init__(self, region="default", cache_key=None): """Construct a new FromCache. :param region: the cache region. Should be a @@ -193,19 +127,34 @@ class FromCache(MapperOption): """ self.region = region self.cache_key = cache_key + self.expiration_time = expiration_time + self.ignore_expiration = ignore_expiration + + def _generate_cache_key(self, statement, parameters, orm_cache): + statement_cache_key = statement._generate_cache_key() + + key = statement_cache_key.to_offline_string( + orm_cache._statement_cache, parameters + ) + repr(self.cache_key) - def process_query(self, query): - """Process a Query during normal loading operation.""" - query._cache_region = self + # print("here's our key...%s" % key) + return key -class RelationshipCache(MapperOption): +class RelationshipCache(FromCache): """Specifies that a Query as called within a "lazy load" should load results from a cache.""" propagate_to_loaders = True - def __init__(self, attribute, region="default", cache_key=None): + def __init__( + self, + attribute, + region="default", + cache_key=None, + expiration_time=None, + ignore_expiration=False, + ): """Construct a new RelationshipCache. :param attribute: A Class.attribute which @@ -221,19 +170,17 @@ class RelationshipCache(MapperOption): """ self.region = region self.cache_key = cache_key + self.expiration_time = expiration_time + self.ignore_expiration = ignore_expiration self._relationship_options = { (attribute.property.parent.class_, attribute.property.key): self } - def process_query_conditionally(self, query): - """Process a Query that is used within a lazy loader. - - (the process_query_conditionally() method is a SQLAlchemy - hook invoked only within lazyload.) + def _process_orm_context(self, orm_context): + current_path = orm_context.loader_strategy_path - """ - if query._current_path: - mapper, prop = query._current_path[-2:] + if current_path: + mapper, prop = current_path[-2:] key = prop.key for cls in mapper.class_.__mro__: @@ -241,8 +188,7 @@ class RelationshipCache(MapperOption): relationship_option = self._relationship_options[ (cls, key) ] - query._cache_region = relationship_option - break + return relationship_option def and_(self, option): """Chain another RelationshipCache option to this one. @@ -254,16 +200,3 @@ class RelationshipCache(MapperOption): """ self._relationship_options.update(option._relationship_options) return self - - def _generate_cache_key(self, path): - """Indicate to the lazy-loader strategy that a "baked" query - may be used by returning ``None``. - - If this method is omitted, the default implementation of - :class:`.MapperOption._generate_cache_key` takes place, which - returns ``False`` to disable the "baked" query from being used. - - .. versionadded:: 1.2.7 - - """ - return None diff --git a/examples/dogpile_caching/environment.py b/examples/dogpile_caching/environment.py index 723ee653d5..7f4f7e7a17 100644 --- a/examples/dogpile_caching/environment.py +++ b/examples/dogpile_caching/environment.py @@ -23,13 +23,11 @@ if py2k: # dogpile cache regions. A home base for cache configurations. regions = {} +# scoped_session. +Session = scoped_session(sessionmaker()) -# scoped_session. Apply our custom CachingQuery class to it, -# using a callable that will associate the dictionary -# of regions with the Query. -Session = scoped_session( - sessionmaker(query_cls=caching_query.query_callable(regions)) -) +cache = caching_query.ORMCache(regions) +cache.listen_on_session(Session) # global declarative base class. Base = declarative_base() diff --git a/examples/dogpile_caching/helloworld.py b/examples/dogpile_caching/helloworld.py index 6b03afbdbb..6e79fc3fa4 100644 --- a/examples/dogpile_caching/helloworld.py +++ b/examples/dogpile_caching/helloworld.py @@ -3,6 +3,7 @@ """ from .caching_query import FromCache +from .environment import cache from .environment import Session from .model import Person @@ -57,10 +58,19 @@ people_two_through_twelve = ( # same list of objects to be loaded, and the same parameters in the # same order, then call invalidate(). print("invalidating everything") -Session.query(Person).options(FromCache("default")).invalidate() -Session.query(Person).options(FromCache("default")).filter( - Person.name.between("person 02", "person 12") -).invalidate() -Session.query(Person).options(FromCache("default", "people_on_range")).filter( - Person.name.between("person 05", "person 15") -).invalidate() + +cache.invalidate(Session.query(Person), {}, FromCache("default")) +cache.invalidate( + Session.query(Person).filter( + Person.name.between("person 02", "person 12") + ), + {}, + FromCache("default"), +) +cache.invalidate( + Session.query(Person).filter( + Person.name.between("person 05", "person 15") + ), + {}, + FromCache("default", "people_on_range"), +) diff --git a/examples/dogpile_caching/local_session_caching.py b/examples/dogpile_caching/local_session_caching.py index 1700c7a636..8f505ead72 100644 --- a/examples/dogpile_caching/local_session_caching.py +++ b/examples/dogpile_caching/local_session_caching.py @@ -75,8 +75,8 @@ if __name__ == "__main__": # of "person 10" q = ( Session.query(Person) - .options(FromCache("local_session")) .filter(Person.name == "person 10") + .execution_options(cache_options=FromCache("local_session")) ) # load from DB diff --git a/examples/performance/short_selects.py b/examples/performance/short_selects.py index db8ab8789c..38bc1508a1 100644 --- a/examples/performance/short_selects.py +++ b/examples/performance/short_selects.py @@ -82,6 +82,29 @@ def test_orm_query_cols_only(n): ).one() +cache = {} + + +@Profiler.profile +def test_cached_orm_query(n): + """test new style cached queries of the full entity.""" + s = Session(bind=engine) + for id_ in random.sample(ids, n): + stmt = s.query(Customer).filter(Customer.id == id_) + s.execute(stmt, execution_options={"compiled_cache": cache}).one() + + +@Profiler.profile +def test_cached_orm_query_cols_only(n): + """test new style cached queries of the full entity.""" + s = Session(bind=engine) + for id_ in random.sample(ids, n): + stmt = s.query( + Customer.id, Customer.name, Customer.description + ).filter(Customer.id == id_) + s.execute(stmt, execution_options={"compiled_cache": cache}).one() + + @Profiler.profile def test_baked_query(n): """test a baked query of the full entity.""" diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c index ff6cadac02..ed6f57470d 100644 --- a/lib/sqlalchemy/cextension/resultproxy.c +++ b/lib/sqlalchemy/cextension/resultproxy.c @@ -651,6 +651,7 @@ static int BaseRow_setkeystyle(BaseRow *self, PyObject *value, void *closure) { if (value == NULL) { + PyErr_SetString( PyExc_TypeError, "Cannot delete the 'key_style' attribute"); diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 05c34c1712..3345d555f6 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1704,12 +1704,14 @@ class MSSQLCompiler(compiler.SQLCompiler): self.process(element.typeclause, **kw), ) - def visit_select(self, select, **kwargs): + def translate_select_structure(self, select_stmt, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. MSSQL 2012 and above are excluded """ + select = select_stmt + if ( not self.dialect._supports_offset_fetch and ( @@ -1741,7 +1743,7 @@ class MSSQLCompiler(compiler.SQLCompiler): limit_clause = select._limit_clause offset_clause = select._offset_clause - kwargs["select_wraps_for"] = select + select = select._generate() select._mssql_visit = True select = ( @@ -1766,9 +1768,9 @@ class MSSQLCompiler(compiler.SQLCompiler): ) else: limitselect = limitselect.where(mssql_rn <= (limit_clause)) - return self.process(limitselect, **kwargs) + return limitselect else: - return compiler.SQLCompiler.visit_select(self, select, **kwargs) + return select @_with_legacy_schema_aliasing def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index dd7d6a4d1c..481ea72633 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -975,16 +975,8 @@ class OracleCompiler(compiler.SQLCompiler): return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) - def _TODO_visit_compound_select(self, select): - """Need to determine how to get ``LIMIT``/``OFFSET`` into a - ``UNION`` for Oracle. - """ - pass - - def visit_select(self, select, **kwargs): - """Look for ``LIMIT`` and OFFSET in a select statement, and if - so tries to wrap it in a subquery with ``rownum`` criterion. - """ + def translate_select_structure(self, select_stmt, **kwargs): + select = select_stmt if not getattr(select, "_oracle_visit", None): if not self.dialect.use_ansi: @@ -1003,7 +995,7 @@ class OracleCompiler(compiler.SQLCompiler): # https://blogs.oracle.com/oraclemagazine/\ # on-rownum-and-limiting-results - kwargs["select_wraps_for"] = orig_select = select + orig_select = select select = select._generate() select._oracle_visit = True @@ -1136,7 +1128,7 @@ class OracleCompiler(compiler.SQLCompiler): offsetselect._for_update_arg = for_update select = offsetselect - return compiler.SQLCompiler.visit_select(self, select, **kwargs) + return select def limit_clause(self, select, **kw): return "" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index ee02899f60..0193ea47cc 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -225,14 +225,10 @@ class Connection(Connectable): A dictionary where :class:`.Compiled` objects will be cached when the :class:`_engine.Connection` compiles a clause - expression into a :class:`.Compiled` object. - It is the user's responsibility to - manage the size of this dictionary, which will have keys - corresponding to the dialect, clause element, the column - names within the VALUES or SET clause of an INSERT or UPDATE, - as well as the "batch" mode for an INSERT or UPDATE statement. - The format of this dictionary is not guaranteed to stay the - same in future releases. + expression into a :class:`.Compiled` object. This dictionary will + supersede the statement cache that may be configured on the + :class:`_engine.Engine` itself. If set to None, caching + is disabled, even if the engine has a configured cache size. Note that the ORM makes use of its own "compiled" caches for some operations, including flush operations. The caching @@ -1159,13 +1155,17 @@ class Connection(Connectable): schema_translate_map = exec_opts.get("schema_translate_map", None) - if "compiled_cache" in exec_opts: + compiled_cache = exec_opts.get( + "compiled_cache", self.dialect._compiled_cache + ) + + if compiled_cache is not None: elem_cache_key = elem._generate_cache_key() else: elem_cache_key = None if elem_cache_key: - cache_key, extracted_params = elem_cache_key + cache_key, extracted_params, _ = elem_cache_key key = ( dialect, cache_key, @@ -1173,8 +1173,7 @@ class Connection(Connectable): bool(schema_translate_map), len(distilled_params) > 1, ) - cache = exec_opts["compiled_cache"] - compiled_sql = cache.get(key) + compiled_sql = compiled_cache.get(key) if compiled_sql is None: compiled_sql = elem.compile( @@ -1185,12 +1184,8 @@ class Connection(Connectable): schema_translate_map=schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, - compile_state_factories=exec_opts.get( - "compile_state_factories", None - ), ) - cache[key] = compiled_sql - + compiled_cache[key] = compiled_sql else: extracted_params = None compiled_sql = elem.compile( @@ -1199,9 +1194,6 @@ class Connection(Connectable): inline=len(distilled_params) > 1, schema_translate_map=schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, - compile_state_factories=exec_opts.get( - "compile_state_factories", None - ), ) ret = self._execute_context( @@ -1430,18 +1422,35 @@ class Connection(Connectable): ) if self._echo: + self.engine.logger.info(statement) + + # stats = context._get_cache_stats() + if not self.engine.hide_parameters: + # TODO: I love the stats but a ton of tests that are hardcoded. + # to certain log output are failing. self.engine.logger.info( "%r", sql_util._repr_params( parameters, batches=10, ismulti=context.executemany ), ) + # self.engine.logger.info( + # "[%s] %r", + # stats, + # sql_util._repr_params( + # parameters, batches=10, ismulti=context.executemany + # ), + # ) else: self.engine.logger.info( "[SQL parameters hidden due to hide_parameters=True]" ) + # self.engine.logger.info( + # "[%s] [SQL parameters hidden due to hide_parameters=True]" + # % (stats,) + # ) evt_handled = False try: @@ -1502,19 +1511,14 @@ class Connection(Connectable): # for "connectionless" execution, we have to close this # Connection after the statement is complete. - if branched.should_close_with_result: + # legacy stuff. + if branched.should_close_with_result and context._soft_closed: assert not self._is_future assert not context._is_future_result # CursorResult already exhausted rows / has no rows. - # close us now. note this is where we call .close() - # on the "branched" connection if we're doing that. - if result._soft_closed: - branched.close() - else: - # CursorResult will close this Connection when no more - # rows to fetch. - result._autoclose_connection = True + # close us now + branched.close() except BaseException as e: self._handle_dbapi_exception( e, statement, parameters, cursor, context diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index e683b6297c..4c912349ea 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -435,6 +435,23 @@ def create_engine(url, **kwargs): .. versionadded:: 1.2.3 + :param query_cache_size: size of the cache used to cache the SQL string + form of queries. Defaults to zero, which disables caching. + + Caching is accomplished on a per-statement basis by generating a + cache key that represents the statement's structure, then generating + string SQL for the current dialect only if that key is not present + in the cache. All statements support caching, however some features + such as an INSERT with a large set of parameters will intentionally + bypass the cache. SQL logging will indicate statistics for each + statement whether or not it were pull from the cache. + + .. seealso:: + + ``engine_caching`` - TODO: this will be an upcoming section describing + the SQL caching system. + + .. versionadded:: 1.4 """ # noqa diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 8d1a1bb57f..fdbf826ed9 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -57,6 +57,9 @@ class CursorResultMetaData(ResultMetaData): returns_rows = True + def _has_key(self, key): + return key in self._keymap + def _for_freeze(self): return SimpleResultMetaData( self._keys, @@ -1203,6 +1206,7 @@ class BaseCursorResult(object): out_parameters = None _metadata = None + _metadata_from_cache = False _soft_closed = False closed = False @@ -1213,7 +1217,6 @@ class BaseCursorResult(object): obj = CursorResult(context) else: obj = LegacyCursorResult(context) - return obj def __init__(self, context): @@ -1247,8 +1250,9 @@ class BaseCursorResult(object): def _init_metadata(self, context, cursor_description): if context.compiled: if context.compiled._cached_metadata: - cached_md = context.compiled._cached_metadata - self._metadata = cached_md._adapt_to_context(context) + cached_md = self.context.compiled._cached_metadata + self._metadata = cached_md + self._metadata_from_cache = True else: self._metadata = ( diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index e30daaeb81..b5cb2a1b2c 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -16,6 +16,7 @@ as the base class for their own corresponding classes. import codecs import random import re +import time import weakref from . import cursor as _cursor @@ -226,6 +227,7 @@ class DefaultDialect(interfaces.Dialect): supports_native_boolean=None, max_identifier_length=None, label_length=None, + query_cache_size=0, # int() is because the @deprecated_params decorator cannot accommodate # the direct reference to the "NO_LINTING" object compiler_linting=int(compiler.NO_LINTING), @@ -257,6 +259,10 @@ class DefaultDialect(interfaces.Dialect): if supports_native_boolean is not None: self.supports_native_boolean = supports_native_boolean self.case_sensitive = case_sensitive + if query_cache_size != 0: + self._compiled_cache = util.LRUCache(query_cache_size) + else: + self._compiled_cache = None self._user_defined_max_identifier_length = max_identifier_length if self._user_defined_max_identifier_length: @@ -702,11 +708,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): result_column_struct = None returned_defaults = None execution_options = util.immutabledict() + + cache_stats = None + invoked_statement = None + _is_implicit_returning = False _is_explicit_returning = False _is_future_result = False _is_server_side = False + _soft_closed = False + # a hook for SQLite's translation of # result column names # NOTE: pyhive is using this hook, can't remove it :( @@ -1011,6 +1023,16 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor = self.create_cursor() return self + def _get_cache_stats(self): + if self.compiled is None: + return "raw SQL" + + now = time.time() + if self.compiled.cache_key is None: + return "gen %.5fs" % (now - self.compiled._gen_time,) + else: + return "cached %.5fs" % (now - self.compiled._gen_time,) + @util.memoized_property def engine(self): return self.root_connection.engine @@ -1234,6 +1256,33 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ): self._setup_out_parameters(result) + if not self._is_future_result: + conn = self.root_connection + assert not conn._is_future + + if not result._soft_closed and conn.should_close_with_result: + result._autoclose_connection = True + + self._soft_closed = result._soft_closed + + # result rewrite/ adapt step. two translations can occur here. + # one is if we are invoked against a cached statement, we want + # to rewrite the ResultMetaData to reflect the column objects + # that are in our current selectable, not the cached one. the + # other is, the CompileState can return an alternative Result + # object. Finally, CompileState might want to tell us to not + # actually do the ResultMetaData adapt step if it in fact has + # changed the selected columns in any case. + compiled = self.compiled + if compiled: + adapt_metadata = ( + result._metadata_from_cache + and not compiled._rewrites_selected_columns + ) + + if adapt_metadata: + result._metadata = result._metadata._adapt_to_context(self) + return result def _setup_out_parameters(self, result): diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 4e6b22820d..0ee80ede4c 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -56,6 +56,9 @@ class ResultMetaData(object): def keys(self): return RMKeyView(self) + def _has_key(self, key): + raise NotImplementedError() + def _for_freeze(self): raise NotImplementedError() @@ -171,6 +174,9 @@ class SimpleResultMetaData(ResultMetaData): self._processors = _processors + def _has_key(self, key): + return key in self._keymap + def _for_freeze(self): unique_filters = self._unique_filters if unique_filters and self._tuplefilter: @@ -287,6 +293,8 @@ class Result(InPlaceGenerative): _no_scalar_onerow = False _yield_per = None + _attributes = util.immutabledict() + def __init__(self, cursor_metadata): self._metadata = cursor_metadata @@ -548,10 +556,21 @@ class Result(InPlaceGenerative): self._generate_rows = True def _row_getter(self): - if self._source_supports_scalars and not self._generate_rows: - return None + if self._source_supports_scalars: + if not self._generate_rows: + return None + else: + _proc = self._process_row + + def process_row( + metadata, processors, keymap, key_style, scalar_obj + ): + return _proc( + metadata, processors, keymap, key_style, (scalar_obj,) + ) - process_row = self._process_row + else: + process_row = self._process_row key_style = self._process_row._default_key_style metadata = self._metadata @@ -771,16 +790,15 @@ class Result(InPlaceGenerative): uniques, strategy = self._unique_strategy def filterrows(make_row, rows, strategy, uniques): + if make_row: + rows = [make_row(row) for row in rows] + if strategy: made_rows = ( - (made_row, strategy(made_row)) - for made_row in [make_row(row) for row in rows] + (made_row, strategy(made_row)) for made_row in rows ) else: - made_rows = ( - (made_row, made_row) - for made_row in [make_row(row) for row in rows] - ) + made_rows = ((made_row, made_row) for made_row in rows) return [ made_row for made_row, sig_row in made_rows @@ -831,7 +849,8 @@ class Result(InPlaceGenerative): num = self._yield_per rows = self._fetchmany_impl(num) - rows = [make_row(row) for row in rows] + if make_row: + rows = [make_row(row) for row in rows] if post_creational_filter: rows = [post_creational_filter(row) for row in rows] return rows @@ -1114,24 +1133,42 @@ class FrozenResult(object): def __init__(self, result): self.metadata = result._metadata._for_freeze() self._post_creational_filter = result._post_creational_filter - self._source_supports_scalars = result._source_supports_scalars self._generate_rows = result._generate_rows + self._source_supports_scalars = result._source_supports_scalars + self._attributes = result._attributes result._post_creational_filter = None - self.data = result.fetchall() + if self._source_supports_scalars: + self.data = list(result._raw_row_iterator()) + else: + self.data = result.fetchall() + + def rewrite_rows(self): + if self._source_supports_scalars: + return [[elem] for elem in self.data] + else: + return [list(row) for row in self.data] - def with_data(self, data): + def with_new_rows(self, tuple_data): fr = FrozenResult.__new__(FrozenResult) fr.metadata = self.metadata fr._post_creational_filter = self._post_creational_filter - fr.data = data + fr._generate_rows = self._generate_rows + fr._attributes = self._attributes + fr._source_supports_scalars = self._source_supports_scalars + + if self._source_supports_scalars: + fr.data = [d[0] for d in tuple_data] + else: + fr.data = tuple_data return fr def __call__(self): result = IteratorResult(self.metadata, iter(self.data)) result._post_creational_filter = self._post_creational_filter - result._source_supports_scalars = self._source_supports_scalars result._generate_rows = self._generate_rows + result._attributes = self._attributes + result._source_supports_scalars = self._source_supports_scalars return result @@ -1143,9 +1180,10 @@ class IteratorResult(Result): """ - def __init__(self, cursor_metadata, iterator): + def __init__(self, cursor_metadata, iterator, raw=None): self._metadata = cursor_metadata self.iterator = iterator + self.raw = raw def _soft_close(self, **kw): self.iterator = iter([]) @@ -1189,28 +1227,23 @@ class ChunkedIteratorResult(IteratorResult): """ - def __init__(self, cursor_metadata, chunks, source_supports_scalars=False): + def __init__( + self, cursor_metadata, chunks, source_supports_scalars=False, raw=None + ): self._metadata = cursor_metadata self.chunks = chunks self._source_supports_scalars = source_supports_scalars - - self.iterator = itertools.chain.from_iterable( - self.chunks(None, self._generate_rows) - ) + self.raw = raw + self.iterator = itertools.chain.from_iterable(self.chunks(None)) def _column_slices(self, indexes): result = super(ChunkedIteratorResult, self)._column_slices(indexes) - self.iterator = itertools.chain.from_iterable( - self.chunks(self._yield_per, self._generate_rows) - ) return result @_generative def yield_per(self, num): self._yield_per = num - self.iterator = itertools.chain.from_iterable( - self.chunks(num, self._generate_rows) - ) + self.iterator = itertools.chain.from_iterable(self.chunks(num)) class MergedResult(IteratorResult): @@ -1238,8 +1271,14 @@ class MergedResult(IteratorResult): self._post_creational_filter = results[0]._post_creational_filter self._no_scalar_onerow = results[0]._no_scalar_onerow self._yield_per = results[0]._yield_per + + # going to try someting w/ this in next rev self._source_supports_scalars = results[0]._source_supports_scalars + self._generate_rows = results[0]._generate_rows + self._attributes = self._attributes.merge_with( + *[r._attributes for r in results] + ) def close(self): self._soft_close(hard=True) diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 24af454b67..112e245f78 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -19,7 +19,6 @@ from .. import exc as sa_exc from .. import util from ..orm import exc as orm_exc from ..orm import strategy_options -from ..orm.context import QueryContext from ..orm.query import Query from ..orm.session import Session from ..sql import func @@ -201,11 +200,12 @@ class BakedQuery(object): self.spoil(full=True) else: for opt in options: - cache_key = opt._generate_path_cache_key(cache_path) - if cache_key is False: - self.spoil(full=True) - elif cache_key is not None: - key += cache_key + if opt._is_legacy_option or opt._is_compile_state: + cache_key = opt._generate_path_cache_key(cache_path) + if cache_key is False: + self.spoil(full=True) + elif cache_key is not None: + key += cache_key self.add_criteria( lambda q: q._with_current_path(effective_path).options(*options), @@ -224,41 +224,32 @@ class BakedQuery(object): def _bake(self, session): query = self._as_query(session) + query.session = None - compile_state = query._compile_state() + # in 1.4, this is where before_compile() event is + # invoked + statement = query._statement_20(orm_results=True) - self._bake_subquery_loaders(session, compile_state) - - # TODO: compile_state clearly needs to be simplified here. - # if the session remains, fails memusage test - compile_state.orm_query = ( - query - ) = ( - compile_state.select_statement - ) = compile_state.query = compile_state.orm_query.with_session(None) - query._execution_options = query._execution_options.union( - {"compiled_cache": self._bakery} - ) - - # we'll be holding onto the query for some of its state, - # so delete some compilation-use-only attributes that can take up - # space - for attr in ( - "_correlate", - "_from_obj", - "_mapper_adapter_map", - "_joinpath", - "_joinpoint", - ): - query.__dict__.pop(attr, None) + # the before_compile() event can create a new Query object + # before it makes the statement. + query = statement.compile_options._orm_query # if the query is not safe to cache, we still do everything as though # we did cache it, since the receiver of _bake() assumes subqueryload # context was set up, etc. - if compile_state.compile_options._bake_ok: - self._bakery[self._effective_key(session)] = compile_state + # + # note also we want to cache the statement itself because this + # allows the statement itself to hold onto its cache key that is + # used by the Connection, which in itself is more expensive to + # generate than what BakedQuery was able to provide in 1.3 and prior + + if query.compile_options._bake_ok: + self._bakery[self._effective_key(session)] = ( + query, + statement, + ) - return compile_state + return query, statement def to_query(self, query_or_session): """Return the :class:`_query.Query` object for use as a subquery. @@ -321,50 +312,6 @@ class BakedQuery(object): return query - def _bake_subquery_loaders(self, session, compile_state): - """convert subquery eager loaders in the cache into baked queries. - - For subquery eager loading to work, all we need here is that the - Query point to the correct session when it is run. However, since - we are "baking" anyway, we may as well also turn the query into - a "baked" query so that we save on performance too. - - """ - compile_state.attributes["baked_queries"] = baked_queries = [] - for k, v in list(compile_state.attributes.items()): - if isinstance(v, dict) and "query" in v: - if "subqueryload_data" in k: - query = v["query"] - bk = BakedQuery(self._bakery, lambda *args: query) - bk._cache_key = self._cache_key + k - bk._bake(session) - baked_queries.append((k, bk._cache_key, v)) - del compile_state.attributes[k] - - def _unbake_subquery_loaders( - self, session, compile_state, context, params, post_criteria - ): - """Retrieve subquery eager loaders stored by _bake_subquery_loaders - and turn them back into Result objects that will iterate just - like a Query object. - - """ - if "baked_queries" not in compile_state.attributes: - return - - for k, cache_key, v in compile_state.attributes["baked_queries"]: - query = v["query"] - bk = BakedQuery( - self._bakery, lambda sess, q=query: q.with_session(sess) - ) - bk._cache_key = cache_key - q = bk.for_session(session) - for fn in post_criteria: - q = q.with_post_criteria(fn) - v = dict(v) - v["query"] = q.params(**params) - context.attributes[k] = v - class Result(object): """Invokes a :class:`.BakedQuery` against a :class:`.Session`. @@ -406,17 +353,19 @@ class Result(object): This adds a function that will be run against the :class:`_query.Query` object after it is retrieved from the - cache. Functions here can be used to alter the query in ways - that **do not affect the SQL output**, such as execution options - and shard identifiers (when using a shard-enabled query object) + cache. This currently includes **only** the + :meth:`_query.Query.params` and :meth:`_query.Query.execution_options` + methods. .. warning:: :meth:`_baked.Result.with_post_criteria` functions are applied to the :class:`_query.Query` object **after** the query's SQL statement - object has been retrieved from the cache. Any operations here - which intend to modify the SQL should ensure that - :meth:`.BakedQuery.spoil` was called first. + object has been retrieved from the cache. Only + :meth:`_query.Query.params` and + :meth:`_query.Query.execution_options` + methods should be used. + .. versionadded:: 1.2 @@ -438,40 +387,41 @@ class Result(object): def _iter(self): bq = self.bq + if not self.session.enable_baked_queries or bq._spoiled: return self._as_query()._iter() - baked_compile_state = bq._bakery.get( - bq._effective_key(self.session), None + query, statement = bq._bakery.get( + bq._effective_key(self.session), (None, None) ) - if baked_compile_state is None: - baked_compile_state = bq._bake(self.session) - - context = QueryContext(baked_compile_state, self.session) - context.session = self.session - - bq._unbake_subquery_loaders( - self.session, - baked_compile_state, - context, - self._params, - self._post_criteria, - ) - - # asserts true - # if isinstance(baked_compile_state.statement, expression.Select): - # assert baked_compile_state.statement._label_style == \ - # LABEL_STYLE_TABLENAME_PLUS_COL + if query is None: + query, statement = bq._bake(self.session) - if context.autoflush and not context.populate_existing: - self.session._autoflush() - q = context.orm_query.params(self._params).with_session(self.session) + q = query.params(self._params) for fn in self._post_criteria: q = fn(q) params = q.load_options._params + q.load_options += {"_orm_query": q} + execution_options = dict(q._execution_options) + execution_options.update( + { + "_sa_orm_load_options": q.load_options, + "compiled_cache": bq._bakery, + } + ) + + result = self.session.execute( + statement, params, execution_options=execution_options + ) + + if result._attributes.get("is_single_entity", False): + result = result.scalars() + + if result._attributes.get("filtered", False): + result = result.unique() - return q._execute_and_instances(context, params=params) + return result def count(self): """return the 'count'. @@ -583,10 +533,10 @@ class Result(object): query = self.bq.steps[0](self.session) return query._get_impl(ident, self._load_on_pk_identity) - def _load_on_pk_identity(self, query, primary_key_identity): + def _load_on_pk_identity(self, session, query, primary_key_identity, **kw): """Load the given primary key identity from the database.""" - mapper = query._only_full_mapper_zero("load_on_pk_identity") + mapper = query._raw_columns[0]._annotations["parententity"] _get_clause, _get_params = mapper._get_clause diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 919f4409a9..1375a24cd5 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -15,10 +15,8 @@ the source distribution. """ -import copy - +from sqlalchemy import event from .. import inspect -from .. import util from ..orm.query import Query from ..orm.session import Session @@ -37,54 +35,32 @@ class ShardedQuery(Query): all subsequent operations with the returned query will be against the single shard regardless of other state. - """ - q = self._clone() - q._shard_id = shard_id - return q + The shard_id can be passed for a 2.0 style execution to the + bind_arguments dictionary of :meth:`.Session.execute`:: - def _execute_and_instances(self, context, params=None): - if params is None: - params = self.load_options._params - - def iter_for_shard(shard_id): - # shallow copy, so that each context may be used by - # ORM load events and similar. - copied_context = copy.copy(context) - copied_context.attributes = context.attributes.copy() - - copied_context.attributes[ - "shard_id" - ] = copied_context.identity_token = shard_id - result_ = self._connection_from_session( - mapper=context.compile_state._bind_mapper(), shard_id=shard_id - ).execute( - copied_context.compile_state.statement, - self.load_options._params, + results = session.execute( + stmt, + bind_arguments={"shard_id": "my_shard"} ) - return self.instances(result_, copied_context) - if context.identity_token is not None: - return iter_for_shard(context.identity_token) - elif self._shard_id is not None: - return iter_for_shard(self._shard_id) - else: - partial = [] - for shard_id in self.query_chooser(self): - result_ = iter_for_shard(shard_id) - partial.append(result_) + """ - return partial[0].merge(*partial[1:]) + q = self._clone() + q._shard_id = shard_id + return q def _execute_crud(self, stmt, mapper): def exec_for_shard(shard_id): - conn = self._connection_from_session( + conn = self.session.connection( mapper=mapper, shard_id=shard_id, clause=stmt, close_with_result=True, ) - result = conn.execute(stmt, self.load_options._params) + result = conn._execute_20( + stmt, self.load_options._params, self._execution_options + ) return result if self._shard_id is not None: @@ -99,38 +75,6 @@ class ShardedQuery(Query): return ShardedResult(results, rowcount) - def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): - """Override the default Query._get_impl() method so that we emit - a query to the DB for each possible identity token, if we don't - have one already. - - """ - - def _db_load_fn(query, primary_key_identity): - # load from the database. The original db_load_fn will - # use the given Query object to load from the DB, so our - # shard_id is what will indicate the DB that we query from. - if self._shard_id is not None: - return db_load_fn(self, primary_key_identity) - else: - ident = util.to_list(primary_key_identity) - # build a ShardedQuery for each shard identifier and - # try to load from the DB - for shard_id in self.id_chooser(self, ident): - q = self.set_shard(shard_id) - o = db_load_fn(q, ident) - if o is not None: - return o - else: - return None - - if identity_token is None and self._shard_id is not None: - identity_token = self._shard_id - - return super(ShardedQuery, self)._get_impl( - primary_key_identity, _db_load_fn, identity_token=identity_token - ) - class ShardedResult(object): """A value object that represents multiple :class:`_engine.CursorResult` @@ -190,11 +134,14 @@ class ShardedSession(Session): """ super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) + + event.listen( + self, "do_orm_execute", execute_and_instances, retval=True + ) self.shard_chooser = shard_chooser self.id_chooser = id_chooser self.query_chooser = query_chooser self.__binds = {} - self.connection_callable = self.connection if shards is not None: for k in shards: self.bind_shard(k, shards[k]) @@ -207,8 +154,8 @@ class ShardedSession(Session): lazy_loaded_from=None, **kw ): - """override the default :meth:`.Session._identity_lookup` method so that we - search for a given non-token primary key identity across all + """override the default :meth:`.Session._identity_lookup` method so + that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from @@ -255,7 +202,14 @@ class ShardedSession(Session): state.identity_token = shard_id return shard_id - def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): + def connection_callable( + self, mapper=None, instance=None, shard_id=None, **kwargs + ): + """Provide a :class:`_engine.Connection` to use in the unit of work + flush process. + + """ + if shard_id is None: shard_id = self._choose_shard_and_assign(mapper, instance) @@ -267,7 +221,7 @@ class ShardedSession(Session): ).connect(**kwargs) def get_bind( - self, mapper, shard_id=None, instance=None, clause=None, **kw + self, mapper=None, shard_id=None, instance=None, clause=None, **kw ): if shard_id is None: shard_id = self._choose_shard_and_assign( @@ -277,3 +231,55 @@ class ShardedSession(Session): def bind_shard(self, shard_id, bind): self.__binds[shard_id] = bind + + +def execute_and_instances(orm_context): + if orm_context.bind_arguments.get("_horizontal_shard", False): + return None + + params = orm_context.parameters + + load_options = orm_context.load_options + session = orm_context.session + orm_query = orm_context.orm_query + + if params is None: + params = load_options._params + + def iter_for_shard(shard_id, load_options): + execution_options = dict(orm_context.execution_options) + + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["_horizontal_shard"] = True + bind_arguments["shard_id"] = shard_id + + load_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_load_options"] = load_options + + return session.execute( + orm_context.statement, + orm_context.parameters, + execution_options, + bind_arguments, + ) + + if load_options._refresh_identity_token is not None: + shard_id = load_options._refresh_identity_token + elif orm_query is not None and orm_query._shard_id is not None: + shard_id = orm_query._shard_id + elif "shard_id" in orm_context.bind_arguments: + shard_id = orm_context.bind_arguments["shard_id"] + else: + shard_id = None + + if shard_id is not None: + return iter_for_shard(shard_id, load_options) + else: + partial = [] + for shard_id in session.query_chooser( + orm_query if orm_query is not None else orm_context.statement + ): + result_ = iter_for_shard(shard_id, load_options) + partial.append(result_) + + return partial[0].merge(*partial[1:]) diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py index 2b76245e0b..58fced8870 100644 --- a/lib/sqlalchemy/future/selectable.py +++ b/lib/sqlalchemy/future/selectable.py @@ -56,7 +56,9 @@ class Select(_LegacySelect): self = cls.__new__(cls) self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self) + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) for ent in entities ] @@ -71,9 +73,9 @@ class Select(_LegacySelect): def _filter_by_zero(self): if self._setup_joins: - meth = SelectState.get_plugin_classmethod( - self, "determine_last_joined_entity" - ) + meth = SelectState.get_plugin_class( + self + ).determine_last_joined_entity _last_joined_entity = meth(self) if _last_joined_entity is not None: return _last_joined_entity @@ -106,7 +108,7 @@ class Select(_LegacySelect): """ target = coercions.expect( - roles.JoinTargetRole, target, apply_plugins=self + roles.JoinTargetRole, target, apply_propagate_attrs=self ) self._setup_joins += ( (target, onclause, None, {"isouter": isouter, "full": full}), @@ -123,12 +125,15 @@ class Select(_LegacySelect): """ + # note the order of parsing from vs. target is important here, as we + # are also deriving the source of the plugin (i.e. the subject mapper + # in an ORM query) which should favor the "from_" over the "target" - target = coercions.expect( - roles.JoinTargetRole, target, apply_plugins=self - ) from_ = coercions.expect( - roles.FromClauseRole, from_, apply_plugins=self + roles.FromClauseRole, from_, apply_propagate_attrs=self + ) + target = coercions.expect( + roles.JoinTargetRole, target, apply_propagate_attrs=self ) self._setup_joins += ( diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 0a353f81c6..110c27811d 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -30,6 +30,7 @@ from .mapper import reconstructor # noqa from .mapper import validates # noqa from .properties import ColumnProperty # noqa from .query import AliasOption # noqa +from .query import FromStatement # noqa from .query import Query # noqa from .relationships import foreign # noqa from .relationships import RelationshipProperty # noqa @@ -39,8 +40,10 @@ from .session import close_all_sessions # noqa from .session import make_transient # noqa from .session import make_transient_to_detached # noqa from .session import object_session # noqa +from .session import ORMExecuteState # noqa from .session import Session # noqa from .session import sessionmaker # noqa +from .session import SessionTransaction # noqa from .strategy_options import Load # noqa from .util import aliased # noqa from .util import Bundle # noqa diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 7b4415bfe3..262a1efc91 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -207,6 +207,10 @@ class QueryableAttribute( def __clause_element__(self): return self.expression + @property + def _from_objects(self): + return self.expression._from_objects + def _bulk_update_tuples(self, value): """Return setter tuples for a bulk UPDATE.""" diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 0a37011340..3acab7df7d 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -18,19 +18,21 @@ from .util import Bundle from .util import join as orm_join from .util import ORMAdapter from .. import exc as sa_exc +from .. import future from .. import inspect from .. import sql from .. import util -from ..future.selectable import Select as FutureSelect from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors from ..sql.base import CacheableOptions +from ..sql.base import CompileState from ..sql.base import Options +from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY +from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL -from ..sql.selectable import Select from ..sql.selectable import SelectState from ..sql.visitors import ExtendedInternalTraversal from ..sql.visitors import InternalTraversal @@ -44,6 +46,8 @@ class QueryContext(object): "orm_query", "query", "load_options", + "bind_arguments", + "execution_options", "session", "autoflush", "populate_existing", @@ -51,7 +55,7 @@ class QueryContext(object): "version_check", "refresh_state", "create_eager_joins", - "propagate_options", + "propagated_loader_options", "attributes", "runid", "partials", @@ -70,20 +74,30 @@ class QueryContext(object): _yield_per = None _refresh_state = None _lazy_loaded_from = None + _orm_query = None _params = util.immutabledict() - def __init__(self, compile_state, session): - query = compile_state.query + def __init__( + self, + compile_state, + session, + load_options, + execution_options=None, + bind_arguments=None, + ): + self.load_options = load_options + self.execution_options = execution_options or {} + self.bind_arguments = bind_arguments or {} self.compile_state = compile_state self.orm_query = compile_state.orm_query - self.query = compile_state.query + self.query = query = compile_state.query self.session = session - self.load_options = load_options = query.load_options - self.propagate_options = set( + self.propagated_loader_options = { o for o in query._with_options if o.propagate_to_loaders - ) + } + self.attributes = dict(compile_state.attributes) self.autoflush = load_options._autoflush @@ -92,11 +106,7 @@ class QueryContext(object): self.version_check = load_options._version_check self.refresh_state = load_options._refresh_state self.yield_per = load_options._yield_per - - if self.refresh_state is not None: - self.identity_token = load_options._refresh_identity_token - else: - self.identity_token = None + self.identity_token = load_options._refresh_identity_token if self.yield_per and compile_state._no_yield_pers: raise sa_exc.InvalidRequestError( @@ -119,25 +129,10 @@ class QueryContext(object): ) -class QueryCompileState(sql.base.CompileState): - _joinpath = _joinpoint = util.immutabledict() - _from_obj_alias = None - _has_mapper_entities = False - - _has_orm_entities = False - multi_row_eager_loaders = False - compound_eager_adapter = None - loaders_require_buffering = False - loaders_require_uniquing = False - - correlate = None - _where_criteria = () - _having_criteria = () - - orm_query = None - +class ORMCompileState(CompileState): class default_compile_options(CacheableOptions): _cache_key_traversal = [ + ("_orm_results", InternalTraversal.dp_boolean), ("_bake_ok", InternalTraversal.dp_boolean), ( "_with_polymorphic_adapt_map", @@ -153,136 +148,310 @@ class QueryCompileState(sql.base.CompileState): ("_for_refresh_state", InternalTraversal.dp_boolean), ] + _orm_results = True _bake_ok = True _with_polymorphic_adapt_map = () _current_path = _path_registry _enable_single_crit = True - _statement = None _enable_eagerloads = True _orm_only_from_obj_alias = True _only_load_props = None _set_base_alias = False _for_refresh_state = False + # non-cache-key elements mostly for legacy use + _statement = None + _orm_query = None + + @classmethod + def merge(cls, other): + return cls + other._state_dict() + + orm_query = None + current_path = _path_registry + def __init__(self, *arg, **kw): raise NotImplementedError() @classmethod - def _create_for_select(cls, statement, compiler, **kw): - if not statement._is_future: - return SelectState(statement, compiler, **kw) + def create_for_statement(cls, statement_container, compiler, **kw): + raise NotImplementedError() - self = cls.__new__(cls) + @classmethod + def _create_for_legacy_query(cls, query, for_statement=False): + stmt = query._statement_20(orm_results=not for_statement) - if not isinstance( - statement.compile_options, cls.default_compile_options - ): - statement.compile_options = cls.default_compile_options - orm_state = self._create_for_legacy_query_via_either(statement) - compile_state = SelectState(orm_state.statement, compiler, **kw) - compile_state._orm_state = orm_state - return compile_state + if query.compile_options._statement is not None: + compile_state_cls = ORMFromStatementCompileState + else: + compile_state_cls = ORMSelectCompileState + + # true in all cases except for two tests in test/orm/test_events.py + # assert stmt.compile_options._orm_query is query + return compile_state_cls._create_for_statement_or_query( + stmt, for_statement=for_statement + ) @classmethod - def _create_future_select_from_query(cls, query): - stmt = FutureSelect.__new__(FutureSelect) - - # the internal state of Query is now a mirror of that of - # Select which can be transferred directly. The Select - # supports compilation into its correct form taking all ORM - # features into account via the plugin and the compile options. - # however it does not export its columns or other attributes - # correctly if deprecated ORM features that adapt plain mapped - # elements are used; for this reason the Select() returned here - # can always support direct execution, but for composition in a larger - # select only works if it does not represent legacy ORM adaption - # features. - stmt.__dict__.update( - dict( - _raw_columns=query._raw_columns, - _compile_state_plugin="orm", # ;) - _where_criteria=query._where_criteria, - _from_obj=query._from_obj, - _legacy_setup_joins=query._legacy_setup_joins, - _order_by_clauses=query._order_by_clauses, - _group_by_clauses=query._group_by_clauses, - _having_criteria=query._having_criteria, - _distinct=query._distinct, - _distinct_on=query._distinct_on, - _with_options=query._with_options, - _with_context_options=query._with_context_options, - _hints=query._hints, - _statement_hints=query._statement_hints, - _correlate=query._correlate, - _auto_correlate=query._auto_correlate, - _limit_clause=query._limit_clause, - _offset_clause=query._offset_clause, - _for_update_arg=query._for_update_arg, - _prefixes=query._prefixes, - _suffixes=query._suffixes, - _label_style=query._label_style, - compile_options=query.compile_options, - # this will be moving but for now make it work like orm.Query - load_options=query.load_options, + def _create_for_statement_or_query( + cls, statement_container, for_statement=False, + ): + raise NotImplementedError() + + @classmethod + def orm_pre_session_exec( + cls, session, statement, execution_options, bind_arguments + ): + if execution_options: + # TODO: will have to provide public API to set some load + # options and also extract them from that API here, likely + # execution options + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options ) + else: + load_options = QueryContext.default_load_options + + bind_arguments["clause"] = statement + + # new in 1.4 - the coercions system is leveraged to allow the + # "subject" mapper of a statement be propagated to the top + # as the statement is built. "subject" mapper is the generally + # standard object used as an identifier for multi-database schemes. + + if "plugin_subject" in statement._propagate_attrs: + bind_arguments["mapper"] = statement._propagate_attrs[ + "plugin_subject" + ].mapper + + if load_options._autoflush: + session._autoflush() + + @classmethod + def orm_setup_cursor_result(cls, session, bind_arguments, result): + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + # cover edge case where ORM entities used in legacy select + # were passed to session.execute: + # session.execute(legacy_select([User.id, User.name])) + # see test_query->test_legacy_tuple_old_select + if not execution_context.compiled.statement._is_future: + return result + + execution_options = execution_context.execution_options + + # we are getting these right above in orm_pre_session_exec(), + # then getting them again right here. + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + querycontext = QueryContext( + compile_state, + session, + load_options, + execution_options, + bind_arguments, ) + return loading.instances(result, querycontext) - return stmt + @property + def _mapper_entities(self): + return ( + ent for ent in self._entities if isinstance(ent, _MapperEntity) + ) + + def _create_with_polymorphic_adapter(self, ext_info, selectable): + if ( + not ext_info.is_aliased_class + and ext_info.mapper.persist_selectable + not in self._polymorphic_adapters + ): + self._mapper_loads_polymorphically_with( + ext_info.mapper, + sql_util.ColumnAdapter( + selectable, ext_info.mapper._equivalent_columns + ), + ) + + def _mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers or [mapper]: + self._polymorphic_adapters[m2] = adapter + for m in m2.iterate_to_root(): + self._polymorphic_adapters[m.local_table] = adapter + + +@sql.base.CompileState.plugin_for("orm", "grouping") +class ORMFromStatementCompileState(ORMCompileState): + _aliased_generations = util.immutabledict() + _from_obj_alias = None + _has_mapper_entities = False + + _has_orm_entities = False + multi_row_eager_loaders = False + compound_eager_adapter = None + loaders_require_buffering = False + loaders_require_uniquing = False + + @classmethod + def create_for_statement(cls, statement_container, compiler, **kw): + compiler._rewrites_selected_columns = True + return cls._create_for_statement_or_query(statement_container) @classmethod - def _create_for_legacy_query( - cls, query, for_statement=False, entities_only=False + def _create_for_statement_or_query( + cls, statement_container, for_statement=False, ): - # as we are seeking to use Select() with ORM state as the - # primary executable element, have all Query objects that are not - # from_statement() convert to a Select() first, then run on that. + # from .query import FromStatement - if query.compile_options._statement is not None: - return cls._create_for_legacy_query_via_either( - query, - for_statement=for_statement, - entities_only=entities_only, - orm_query=query, - ) + # assert isinstance(statement_container, FromStatement) + + self = cls.__new__(cls) + self._primary_entity = None + + self.orm_query = statement_container.compile_options._orm_query + + self.statement_container = self.query = statement_container + self.requested_statement = statement_container.element + + self._entities = [] + self._with_polymorphic_adapt_map = {} + self._polymorphic_adapters = {} + self._no_yield_pers = set() + + _QueryEntity.to_compile_state(self, statement_container._raw_columns) + + self.compile_options = statement_container.compile_options + + self.current_path = statement_container.compile_options._current_path + + if statement_container._with_options: + self.attributes = {"_unbound_load_dedupes": set()} + + for opt in statement_container._with_options: + if opt._is_compile_state: + opt.process_compile_state(self) + else: + self.attributes = {} + + if statement_container._with_context_options: + for fn, key in statement_container._with_context_options: + fn(self) + + self.primary_columns = [] + self.secondary_columns = [] + self.eager_joins = {} + self.single_inh_entities = {} + self.create_eager_joins = [] + self._fallback_from_clauses = [] + self._setup_for_statement() + + return self + + def _setup_for_statement(self): + statement = self.requested_statement + if ( + isinstance(statement, expression.SelectBase) + and not statement._is_textual + and not statement.use_labels + ): + self.statement = statement.apply_labels() else: - assert query.compile_options._statement is None + self.statement = statement + self.order_by = None - stmt = cls._create_future_select_from_query(query) + if isinstance(self.statement, expression.TextClause): + # setup for all entities. Currently, this is not useful + # for eager loaders, as the eager loaders that work are able + # to do their work entirely in row_processor. + for entity in self._entities: + entity.setup_compile_state(self) - return cls._create_for_legacy_query_via_either( - stmt, - for_statement=for_statement, - entities_only=entities_only, - orm_query=query, + # we did the setup just to get primary columns. + self.statement = expression.TextualSelect( + self.statement, self.primary_columns, positional=False ) + else: + # allow TextualSelect with implicit columns as well + # as select() with ad-hoc columns, see test_query::TextTest + self._from_obj_alias = sql.util.ColumnAdapter( + self.statement, adapt_on_names=True + ) + # set up for eager loaders, however if we fix subqueryload + # it should not need to do this here. the model of eager loaders + # that can work entirely in row_processor might be interesting + # here though subqueryloader has a lot of upfront work to do + # see test/orm/test_query.py -> test_related_eagerload_against_text + # for where this part makes a difference. would rather have + # subqueryload figure out what it needs more intelligently. + # for entity in self._entities: + # entity.setup_compile_state(self) + + def _adapt_col_list(self, cols, current_adapter): + return cols + + def _get_current_adapter(self): + return None + + +@sql.base.CompileState.plugin_for("orm", "select") +class ORMSelectCompileState(ORMCompileState, SelectState): + _joinpath = _joinpoint = util.immutabledict() + _from_obj_alias = None + _has_mapper_entities = False + + _has_orm_entities = False + multi_row_eager_loaders = False + compound_eager_adapter = None + loaders_require_buffering = False + loaders_require_uniquing = False + + correlate = None + _where_criteria = () + _having_criteria = () + + orm_query = None + + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + if not statement._is_future: + return SelectState(statement, compiler, **kw) + + compiler._rewrites_selected_columns = True + + orm_state = cls._create_for_statement_or_query( + statement, for_statement=True + ) + SelectState.__init__(orm_state, orm_state.statement, compiler, **kw) + return orm_state @classmethod - def _create_for_legacy_query_via_either( - cls, query, for_statement=False, entities_only=False, orm_query=None + def _create_for_statement_or_query( + cls, query, for_statement=False, _entities_only=False, ): + assert isinstance(query, future.Select) + + query.compile_options = cls.default_compile_options.merge( + query.compile_options + ) self = cls.__new__(cls) self._primary_entity = None - self.has_select = isinstance(query, Select) + self.orm_query = query.compile_options._orm_query - if orm_query: - self.orm_query = orm_query - self.query = query - self.has_orm_query = True - else: - self.query = query - if not self.has_select: - self.orm_query = query - self.has_orm_query = True - else: - self.orm_query = None - self.has_orm_query = False + self.query = query self.select_statement = select_statement = query + if not hasattr(select_statement.compile_options, "_orm_results"): + select_statement.compile_options = cls.default_compile_options + select_statement.compile_options += {"_orm_results": for_statement} + else: + for_statement = not select_statement.compile_options._orm_results + self.query = query self._entities = [] @@ -300,19 +469,28 @@ class QueryCompileState(sql.base.CompileState): _QueryEntity.to_compile_state(self, select_statement._raw_columns) - if entities_only: + if _entities_only: return self self.compile_options = query.compile_options + + # TODO: the name of this flag "for_statement" has to change, + # as it is difficult to distinguish from the "query._statement" use + # case which is something totally different self.for_statement = for_statement - if self.has_orm_query and not for_statement: - self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL + # determine label style. we can make different decisions here. + # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY + # rather than LABEL_STYLE_NONE, and if we can use disambiguate style + # for new style ORM selects too. + if self.select_statement._label_style is LABEL_STYLE_NONE: + if self.orm_query and not for_statement: + self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL + else: + self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY else: self.label_style = self.select_statement._label_style - self.labels = self.label_style is LABEL_STYLE_TABLENAME_PLUS_COL - self.current_path = select_statement.compile_options._current_path self.eager_order_by = () @@ -321,7 +499,7 @@ class QueryCompileState(sql.base.CompileState): self.attributes = {"_unbound_load_dedupes": set()} for opt in self.select_statement._with_options: - if not opt._is_legacy_option: + if opt._is_compile_state: opt.process_compile_state(self) else: self.attributes = {} @@ -341,13 +519,50 @@ class QueryCompileState(sql.base.CompileState): info.selectable for info in select_statement._from_obj ] - if self.compile_options._statement is not None: - self._setup_for_statement() - else: - self._setup_for_generate() + self._setup_for_generate() return self + @classmethod + def _create_entities_collection(cls, query): + """Creates a partial ORMSelectCompileState that includes + the full collection of _MapperEntity and other _QueryEntity objects. + + Supports a few remaining use cases that are pre-compilation + but still need to gather some of the column / adaption information. + + """ + self = cls.__new__(cls) + + self._entities = [] + self._primary_entity = None + self._aliased_generations = {} + self._polymorphic_adapters = {} + + # legacy: only for query.with_polymorphic() + self._with_polymorphic_adapt_map = wpam = dict( + query.compile_options._with_polymorphic_adapt_map + ) + if wpam: + self._setup_with_polymorphics() + + _QueryEntity.to_compile_state(self, query._raw_columns) + return self + + @classmethod + def determine_last_joined_entity(cls, statement): + setup_joins = statement._setup_joins + + if not setup_joins: + return None + + (target, onclause, from_, flags) = setup_joins[-1] + + if isinstance(target, interfaces.PropComparator): + return target.entity + else: + return target + def _setup_with_polymorphics(self): # legacy: only for query.with_polymorphic() for ext_info, wp in self._with_polymorphic_adapt_map.items(): @@ -404,34 +619,6 @@ class QueryCompileState(sql.base.CompileState): return None - def _deep_entity_zero(self): - """Return a 'deep' entity; this is any entity we can find associated - with the first entity / column experssion. this is used only for - session.get_bind(). - - it is hoped this concept can be removed in an upcoming change - to the ORM execution model. - - """ - for ent in self.from_clauses: - if "parententity" in ent._annotations: - return ent._annotations["parententity"].mapper - for ent in self._entities: - ezero = ent._deep_entity_zero() - if ezero is not None: - return ezero.mapper - else: - return None - - @property - def _mapper_entities(self): - for ent in self._entities: - if isinstance(ent, _MapperEntity): - yield ent - - def _bind_mapper(self): - return self._deep_entity_zero() - def _only_full_mapper_zero(self, methname): if self._entities != [self._primary_entity]: raise sa_exc.InvalidRequestError( @@ -490,7 +677,7 @@ class QueryCompileState(sql.base.CompileState): else query._order_by_clauses ) - if query._having_criteria is not None: + if query._having_criteria: self._having_criteria = tuple( current_adapter(crit, True, True) if current_adapter else crit for crit in query._having_criteria @@ -527,7 +714,7 @@ class QueryCompileState(sql.base.CompileState): for s in query._correlate ) ) - elif self.has_select and not query._auto_correlate: + elif not query._auto_correlate: self.correlate = (None,) # PART II @@ -582,33 +769,6 @@ class QueryCompileState(sql.base.CompileState): {"deepentity": ezero} ) - def _setup_for_statement(self): - compile_options = self.compile_options - - if ( - isinstance(compile_options._statement, expression.SelectBase) - and not compile_options._statement._is_textual - and not compile_options._statement.use_labels - ): - self.statement = compile_options._statement.apply_labels() - else: - self.statement = compile_options._statement - self.order_by = None - - if isinstance(self.statement, expression.TextClause): - # setup for all entities, including contains_eager entities. - for entity in self._entities: - entity.setup_compile_state(self) - self.statement = expression.TextualSelect( - self.statement, self.primary_columns, positional=False - ) - else: - # allow TextualSelect with implicit columns as well - # as select() with ad-hoc columns, see test_query::TextTest - self._from_obj_alias = sql.util.ColumnAdapter( - self.statement, adapt_on_names=True - ) - def _compound_eager_statement(self): # for eager joins present and LIMIT/OFFSET/DISTINCT, # wrap the query inside a select, @@ -659,9 +819,10 @@ class QueryCompileState(sql.base.CompileState): self.compound_eager_adapter = sql_util.ColumnAdapter(inner, equivs) - statement = sql.select( - [inner] + self.secondary_columns, use_labels=self.labels + statement = future.select( + *([inner] + self.secondary_columns) # use_labels=self.labels ) + statement._label_style = self.label_style # Oracle however does not allow FOR UPDATE on the subquery, # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL @@ -752,6 +913,7 @@ class QueryCompileState(sql.base.CompileState): group_by, ): + Select = future.Select statement = Select.__new__(Select) statement._raw_columns = raw_columns statement._from_obj = from_obj @@ -794,25 +956,6 @@ class QueryCompileState(sql.base.CompileState): return statement - def _create_with_polymorphic_adapter(self, ext_info, selectable): - if ( - not ext_info.is_aliased_class - and ext_info.mapper.persist_selectable - not in self._polymorphic_adapters - ): - self._mapper_loads_polymorphically_with( - ext_info.mapper, - sql_util.ColumnAdapter( - selectable, ext_info.mapper._equivalent_columns - ), - ) - - def _mapper_loads_polymorphically_with(self, mapper, adapter): - for m2 in mapper._with_polymorphic_mappers or [mapper]: - self._polymorphic_adapters[m2] = adapter - for m in m2.iterate_to_root(): - self._polymorphic_adapters[m.local_table] = adapter - def _adapt_polymorphic_element(self, element): if "parententity" in element._annotations: search = element._annotations["parententity"] @@ -924,6 +1067,8 @@ class QueryCompileState(sql.base.CompileState): # onclause = right right = None + elif "parententity" in right._annotations: + right = right._annotations["parententity"].entity if onclause is None: r_info = inspect(right) @@ -932,7 +1077,6 @@ class QueryCompileState(sql.base.CompileState): "Expected mapped entity or " "selectable/table as join target" ) - if isinstance(onclause, interfaces.PropComparator): of_type = getattr(onclause, "_of_type", None) else: @@ -1584,7 +1728,7 @@ class QueryCompileState(sql.base.CompileState): "aliased_generation": aliased_generation, } - return right, inspect(right), onclause + return inspect(right), right, onclause def _update_joinpoint(self, jp): self._joinpoint = jp @@ -1668,14 +1812,8 @@ class QueryCompileState(sql.base.CompileState): def _column_descriptions(query_or_select_stmt): - # TODO: this is a hack for now, as it is a little bit non-performant - # to build up QueryEntity for every entity right now. - ctx = QueryCompileState._create_for_legacy_query_via_either( - query_or_select_stmt, - entities_only=True, - orm_query=query_or_select_stmt - if not isinstance(query_or_select_stmt, Select) - else None, + ctx = ORMSelectCompileState._create_entities_collection( + query_or_select_stmt ) return [ { @@ -1731,23 +1869,6 @@ def _entity_from_pre_ent_zero(query_or_augmented_select): return ent -@sql.base.CompileState.plugin_for( - "orm", "select", "determine_last_joined_entity" -) -def _determine_last_joined_entity(statement): - setup_joins = statement._setup_joins - - if not setup_joins: - return None - - (target, onclause, from_, flags) = setup_joins[-1] - - if isinstance(target, interfaces.PropComparator): - return target.entity - else: - return target - - def _legacy_determine_last_joined_entity(setup_joins, entity_zero): """given the legacy_setup_joins collection at a point in time, figure out what the "filter by entity" would be in terms @@ -1929,9 +2050,6 @@ class _MapperEntity(_QueryEntity): def entity_zero_or_selectable(self): return self.entity_zero - def _deep_entity_zero(self): - return self.entity_zero - def corresponds_to(self, entity): return _entity_corresponds_to(self.entity_zero, entity) @@ -2093,14 +2211,6 @@ class _BundleEntity(_QueryEntity): else: return None - def _deep_entity_zero(self): - for ent in self._entities: - ezero = ent._deep_entity_zero() - if ezero is not None: - return ezero - else: - return None - def setup_compile_state(self, compile_state): for ent in self._entities: ent.setup_compile_state(compile_state) @@ -2175,17 +2285,6 @@ class _RawColumnEntity(_ColumnEntity): ) self._extra_entities = (self.expr, self.column) - def _deep_entity_zero(self): - for obj in visitors.iterate( - self.column, {"column_tables": True, "column_collections": False}, - ): - if "parententity" in obj._annotations: - return obj._annotations["parententity"] - elif "deepentity" in obj._annotations: - return obj._annotations["deepentity"] - else: - return None - def corresponds_to(self, entity): return False @@ -2276,9 +2375,6 @@ class _ORMColumnEntity(_ColumnEntity): ezero, ezero.selectable ) - def _deep_entity_zero(self): - return self.mapper - def corresponds_to(self, entity): if _is_aliased_class(entity): # TODO: polymorphic subclasses ? @@ -2342,8 +2438,3 @@ class _ORMColumnEntity(_ColumnEntity): compile_state.primary_columns.append(column) compile_state.attributes[("fetch_column", self)] = column - - -sql.base.CompileState.plugin_for("orm", "select")( - QueryCompileState._create_for_select -) diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index f5d1918603..be7aa272ea 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1397,6 +1397,43 @@ class SessionEvents(event.Events): event_key.base_listen(**kw) + def do_orm_execute(self, orm_execute_state): + """Intercept statement executions that occur in terms of a :class:`.Session`. + + This event is invoked for all top-level SQL statements invoked + from the :meth:`_orm.Session.execute` method. As of SQLAlchemy 1.4, + all ORM queries emitted on behalf of a :class:`_orm.Session` will + flow through this method, so this event hook provides the single + point at which ORM queries of all types may be intercepted before + they are invoked, and additionally to replace their execution with + a different process. + + This event is a ``do_`` event, meaning it has the capability to replace + the operation that the :meth:`_orm.Session.execute` method normally + performs. The intended use for this includes sharding and + result-caching schemes which may seek to invoke the same statement + across multiple database connections, returning a result that is + merged from each of them, or which don't invoke the statement at all, + instead returning data from a cache. + + The hook intends to replace the use of the + ``Query._execute_and_instances`` method that could be subclassed prior + to SQLAlchemy 1.4. + + :param orm_execute_state: an instance of :class:`.ORMExecuteState` + which contains all information about the current execution, as well + as helper functions used to derive other commonly required + information. See that object for details. + + .. seealso:: + + :class:`.ORMExecuteState` + + + .. versionadded:: 1.4 + + """ + def after_transaction_create(self, session, transaction): """Execute when a new :class:`.SessionTransaction` is created. diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 313f2fda8d..6c0f5d3ef4 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -64,6 +64,12 @@ __all__ = ( ) +class ORMStatementRole(roles.CoerceTextStatementRole): + _role_name = ( + "Executable SQL or text() construct, including ORM " "aware objects" + ) + + class ORMColumnsClauseRole(roles.ColumnsClauseRole): _role_name = "ORM mapped entity, aliased entity, or Column expression" @@ -662,8 +668,15 @@ class StrategizedProperty(MapperProperty): ) -class LoaderOption(HasCacheKey): - """Describe a modification to an ORM statement at compilation time. +class ORMOption(object): + """Base class for option objects that are passed to ORM queries. + + These options may be consumed by :meth:`.Query.options`, + :meth:`.Select.options`, or in a more general sense by any + :meth:`.Executable.options` method. They are interpreted at + statement compile time or execution time in modern use. The + deprecated :class:`.MapperOption` is consumed at ORM query construction + time. .. versionadded:: 1.4 @@ -680,6 +693,18 @@ class LoaderOption(HasCacheKey): """ + _is_compile_state = False + + +class LoaderOption(HasCacheKey, ORMOption): + """Describe a loader modification to an ORM statement at compilation time. + + .. versionadded:: 1.4 + + """ + + _is_compile_state = True + def process_compile_state(self, compile_state): """Apply a modification to a given :class:`.CompileState`.""" @@ -693,18 +718,39 @@ class LoaderOption(HasCacheKey): return False +class UserDefinedOption(ORMOption): + """Base class for a user-defined option that can be consumed from the + :meth:`.SessionEvents.do_orm_execute` event hook. + + """ + + _is_legacy_option = False + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" Query objects produced during lazy loads + or refresh operations. + + """ + + def __init__(self, payload=None): + self.payload = payload + + def _gen_cache_key(self, *arg, **kw): + return () + + @util.deprecated_cls( "1.4", "The :class:`.MapperOption class is deprecated and will be removed " - "in a future release. ORM options now run within the compilation " - "phase and are based on the :class:`.LoaderOption` class which is " - "intended for internal consumption only. For " - "modifications to queries on a per-execution basis, the " - ":meth:`.before_execute` hook will now intercept ORM :class:`.Query` " - "objects before they are invoked", + "in a future release. For " + "modifications to queries on a per-execution basis, use the " + ":class:`.UserDefinedOption` class to establish state within a " + ":class:`.Query` or other Core statement, then use the " + ":meth:`.SessionEvents.before_orm_execute` hook to consume them.", constructor=None, ) -class MapperOption(object): +class MapperOption(ORMOption): """Describe a modification to a Query""" _is_legacy_option = True @@ -735,23 +781,6 @@ class MapperOption(object): def _generate_path_cache_key(self, path): """Used by the "baked lazy loader" to see if this option can be cached. - The "baked lazy loader" refers to the :class:`_query.Query` that is - produced during a lazy load operation for a mapped relationship. - It does not yet apply to the "lazy" load operation for deferred - or expired column attributes, however this may change in the future. - - This loader generates SQL for a query only once and attempts to cache - it; from that point on, if the SQL has been cached it will no longer - run the :meth:`_query.Query.options` method of the - :class:`_query.Query`. The - :class:`.MapperOption` object that wishes to participate within a lazy - load operation therefore needs to tell the baked loader that it either - needs to forego this caching, or that it needs to include the state of - the :class:`.MapperOption` itself as part of its cache key, otherwise - SQL or other query state that has been affected by the - :class:`.MapperOption` may be cached in place of a query that does not - include these modifications, or the option may not be invoked at all. - By default, this method returns the value ``False``, which means the :class:`.BakedQuery` generated by the lazy loader will not cache the SQL when this :class:`.MapperOption` is present. @@ -760,26 +789,10 @@ class MapperOption(object): an unlimited number of :class:`_query.Query` objects for an unlimited number of :class:`.MapperOption` objects. - .. versionchanged:: 1.2.8 the default return value of - :meth:`.MapperOption._generate_cache_key` is False; previously it - was ``None`` indicating "safe to cache, don't include as part of - the cache key" - - To enable caching of :class:`_query.Query` objects within lazy loaders - , a - given :class:`.MapperOption` that returns a cache key must return a key - that uniquely identifies the complete state of this option, which will - match any other :class:`.MapperOption` that itself retains the - identical state. This includes path options, flags, etc. It should - be a state that is repeatable and part of a limited set of possible - options. - - If the :class:`.MapperOption` does not apply to the given path and - would not affect query results on such a path, it should return None, - indicating the :class:`_query.Query` is safe to cache for this given - loader path and that this :class:`.MapperOption` need not be - part of the cache key. - + For caching support it is recommended to use the + :class:`.UserDefinedOption` class in conjunction with + the :meth:`.Session.do_orm_execute` method so that statements may + be modified before they are cached. """ return False diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 48641685e3..616e757a39 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -26,6 +26,7 @@ from .base import _SET_DEFERRED_EXPIRED from .util import _none_set from .util import state_str from .. import exc as sa_exc +from .. import future from .. import util from ..engine import result_tuple from ..engine.result import ChunkedIteratorResult @@ -36,8 +37,20 @@ from ..sql import util as sql_util _new_runid = util.counter() -def instances(query, cursor, context): - """Return an ORM result as an iterator.""" +def instances(cursor, context): + """Return a :class:`.Result` given an ORM query context. + + :param cursor: a :class:`.CursorResult`, generated by a statement + which came from :class:`.ORMCompileState` + + :param context: a :class:`.QueryContext` object + + :return: a :class:`.Result` object representing ORM results + + .. versionchanged:: 1.4 The instances() function now uses + :class:`.Result` objects and has an all new interface. + + """ context.runid = _new_runid() context.post_load_paths = {} @@ -80,7 +93,7 @@ def instances(query, cursor, context): ], ) - def chunks(size, as_tuples): + def chunks(size): while True: yield_per = size @@ -94,7 +107,7 @@ def instances(query, cursor, context): else: fetch = cursor.fetchall() - if not as_tuples: + if single_entity: proc = process[0] rows = [proc(row) for row in fetch] else: @@ -111,20 +124,62 @@ def instances(query, cursor, context): break result = ChunkedIteratorResult( - row_metadata, chunks, source_supports_scalars=single_entity + row_metadata, chunks, source_supports_scalars=single_entity, raw=cursor + ) + + result._attributes = result._attributes.union( + dict(filtered=filtered, is_single_entity=single_entity) ) + if context.yield_per: result.yield_per(context.yield_per) - if single_entity: - result = result.scalars() + return result - filtered = context.compile_state._has_mapper_entities - if filtered: - result = result.unique() +@util.preload_module("sqlalchemy.orm.context") +def merge_frozen_result(session, statement, frozen_result, load=True): + querycontext = util.preloaded.orm_context - return result + if load: + # flush current contents if we expect to load data + session._autoflush() + + ctx = querycontext.ORMSelectCompileState._create_entities_collection( + statement + ) + + autoflush = session.autoflush + try: + session.autoflush = False + mapped_entities = [ + i + for i, e in enumerate(ctx._entities) + if isinstance(e, querycontext._MapperEntity) + ] + keys = [ent._label_name for ent in ctx._entities] + + keyed_tuple = result_tuple( + keys, [ent._extra_entities for ent in ctx._entities] + ) + + result = [] + for newrow in frozen_result.rewrite_rows(): + for i in mapped_entities: + if newrow[i] is not None: + newrow[i] = session._merge( + attributes.instance_state(newrow[i]), + attributes.instance_dict(newrow[i]), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + + result.append(keyed_tuple(newrow)) + + return frozen_result.with_new_rows(result) + finally: + session.autoflush = autoflush @util.preload_module("sqlalchemy.orm.context") @@ -145,9 +200,7 @@ def merge_result(query, iterator, load=True): else: frozen_result = None - ctx = querycontext.QueryCompileState._create_for_legacy_query( - query, entities_only=True - ) + ctx = querycontext.ORMSelectCompileState._create_entities_collection(query) autoflush = session.autoflush try: @@ -235,12 +288,15 @@ def get_from_identity(session, mapper, key, passive): def load_on_ident( - query, + session, + statement, key, + load_options=None, refresh_state=None, with_for_update=None, only_load_props=None, no_autoflush=False, + bind_arguments=util.immutabledict(), ): """Load the given identity key from the database.""" if key is not None: @@ -249,38 +305,59 @@ def load_on_ident( else: ident = identity_token = None - if no_autoflush: - query = query.autoflush(False) - return load_on_pk_identity( - query, + session, + statement, ident, + load_options=load_options, refresh_state=refresh_state, with_for_update=with_for_update, only_load_props=only_load_props, identity_token=identity_token, + no_autoflush=no_autoflush, + bind_arguments=bind_arguments, ) def load_on_pk_identity( - query, + session, + statement, primary_key_identity, + load_options=None, refresh_state=None, with_for_update=None, only_load_props=None, identity_token=None, + no_autoflush=False, + bind_arguments=util.immutabledict(), ): """Load the given primary key identity from the database.""" + query = statement + q = query._clone() + + # TODO: fix these imports .... + from .context import QueryContext, ORMCompileState + + if load_options is None: + load_options = QueryContext.default_load_options + + compile_options = ORMCompileState.default_compile_options.merge( + q.compile_options + ) + + # checking that query doesnt have criteria on it + # just delete it here w/ optional assertion? since we are setting a + # where clause also if refresh_state is None: - q = query._clone() - q._get_condition() - else: - q = query._clone() + _no_criterion_assertion(q, "get", order_by=False, distinct=False) if primary_key_identity is not None: - mapper = query._only_full_mapper_zero("load_on_pk_identity") + # mapper = query._only_full_mapper_zero("load_on_pk_identity") + + # TODO: error checking? + mapper = query._raw_columns[0]._annotations["parententity"] (_get_clause, _get_params) = mapper._get_clause @@ -320,9 +397,8 @@ def load_on_pk_identity( ] ) - q.load_options += {"_params": params} + load_options += {"_params": params} - # with_for_update needs to be query.LockmodeArg() if with_for_update is not None: version_check = True q._for_update_arg = with_for_update @@ -333,11 +409,15 @@ def load_on_pk_identity( version_check = False if refresh_state and refresh_state.load_options: - # if refresh_state.load_path.parent: - q = q._with_current_path(refresh_state.load_path.parent) - q = q.options(refresh_state.load_options) + compile_options += {"_current_path": refresh_state.load_path.parent} + q = q.options(*refresh_state.load_options) - q._get_options( + # TODO: most of the compile_options that are not legacy only involve this + # function, so try to see if handling of them can mostly be local to here + + q.compile_options, load_options = _set_get_options( + compile_options, + load_options, populate_existing=bool(refresh_state), version_check=version_check, only_load_props=only_load_props, @@ -346,12 +426,76 @@ def load_on_pk_identity( ) q._order_by = None + if no_autoflush: + load_options += {"_autoflush": False} + + result = ( + session.execute( + q, + params=load_options._params, + execution_options={"_sa_orm_load_options": load_options}, + bind_arguments=bind_arguments, + ) + .unique() + .scalars() + ) + try: - return q.one() + return result.one() except orm_exc.NoResultFound: return None +def _no_criterion_assertion(stmt, meth, order_by=True, distinct=True): + if ( + stmt._where_criteria + or stmt.compile_options._statement is not None + or stmt._from_obj + or stmt._legacy_setup_joins + or stmt._limit_clause is not None + or stmt._offset_clause is not None + or stmt._group_by_clauses + or (order_by and stmt._order_by_clauses) + or (distinct and stmt._distinct) + ): + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth + ) + + +def _set_get_options( + compile_opt, + load_opt, + populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None, + identity_token=None, +): + + compile_options = {} + load_options = {} + if version_check: + load_options["_version_check"] = version_check + if populate_existing: + load_options["_populate_existing"] = populate_existing + if refresh_state: + load_options["_refresh_state"] = refresh_state + compile_options["_for_refresh_state"] = True + if only_load_props: + compile_options["_only_load_props"] = frozenset(only_load_props) + if identity_token: + load_options["_refresh_identity_token"] = identity_token + + if load_options: + load_opt += load_options + if compile_options: + compile_opt += compile_options + + return compile_opt, load_opt + + def _setup_entity_query( compile_state, mapper, @@ -487,7 +631,7 @@ def _instance_processor( context, path, mapper, result, adapter, populators ) - propagate_options = context.propagate_options + propagated_loader_options = context.propagated_loader_options load_path = ( context.compile_state.current_path + path if context.compile_state.current_path.path @@ -639,8 +783,8 @@ def _instance_processor( # be conservative about setting load_path when populate_existing # is in effect; want to maintain options from the original # load. see test_expire->test_refresh_maintains_deferred_options - if isnew and (propagate_options or not populate_existing): - state.load_options = propagate_options + if isnew and (propagated_loader_options or not populate_existing): + state.load_options = propagated_loader_options state.load_path = load_path _populate_full( @@ -1055,7 +1199,7 @@ def load_scalar_attributes(mapper, state, attribute_names, passive): result = False - no_autoflush = passive & attributes.NO_AUTOFLUSH + no_autoflush = bool(passive & attributes.NO_AUTOFLUSH) # in the case of inheritance, particularly concrete and abstract # concrete inheritance, the class manager might have some keys @@ -1080,10 +1224,16 @@ def load_scalar_attributes(mapper, state, attribute_names, passive): # note: using from_statement() here means there is an adaption # with adapt_on_names set up. the other option is to make the # aliased() against a subquery which affects the SQL. + + from .query import FromStatement + + stmt = FromStatement(mapper, statement).options( + strategy_options.Load(mapper).undefer("*") + ) + result = load_on_ident( - session.query(mapper) - .options(strategy_options.Load(mapper).undefer("*")) - .from_statement(statement), + session, + stmt, None, only_load_props=attribute_names, refresh_state=state, @@ -1121,7 +1271,8 @@ def load_scalar_attributes(mapper, state, attribute_names, passive): return result = load_on_ident( - session.query(mapper), + session, + future.select(mapper).apply_labels(), identity_key, refresh_state=state, only_load_props=attribute_names, diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a6fb1039fc..7bfe70c36b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2237,6 +2237,8 @@ class Mapper( "parentmapper": self, "compile_state_plugin": "orm", } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} ) @property diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 1698a51819..2e59417132 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -228,10 +228,29 @@ class RootRegistry(PathRegistry): PathRegistry.root = RootRegistry() +class PathToken(HasCacheKey, str): + """cacheable string token""" + + _intern = {} + + def _gen_cache_key(self, anon_map, bindparams): + return (str(self),) + + @classmethod + def intern(cls, strvalue): + if strvalue in cls._intern: + return cls._intern[strvalue] + else: + cls._intern[strvalue] = result = PathToken(strvalue) + return result + + class TokenRegistry(PathRegistry): __slots__ = ("token", "parent", "path", "natural_path") def __init__(self, parent, token): + token = PathToken.intern(token) + self.token = token self.parent = parent self.path = parent.path + (token,) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index d14f6c27b9..163ebf22a5 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -25,6 +25,7 @@ from . import loading from . import sync from .base import state_str from .. import exc as sa_exc +from .. import future from .. import sql from .. import util from ..sql import coercions @@ -1424,8 +1425,10 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): if toload_now: state.key = base_mapper._identity_key_from_state(state) + stmt = future.select(mapper).apply_labels() loading.load_on_ident( - uowtransaction.session.query(mapper), + uowtransaction.session, + stmt, state.key, refresh_state=state, only_load_props=toload_now, @@ -1723,7 +1726,7 @@ class BulkUD(object): self.context ) = compile_state = query._compile_state() - self.mapper = compile_state._bind_mapper() + self.mapper = compile_state._entity_zero() if isinstance( compile_state._entities[0], query_context._RawColumnEntity, diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 027786c190..4cf501e3f3 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -346,14 +346,20 @@ class ColumnProperty(StrategizedProperty): pe = self._parententity # no adapter, so we aren't aliased # assert self._parententity is self._parentmapper - return self.prop.columns[0]._annotate( - { - "entity_namespace": pe, - "parententity": pe, - "parentmapper": pe, - "orm_key": self.prop.key, - "compile_state_plugin": "orm", - } + return ( + self.prop.columns[0] + ._annotate( + { + "entity_namespace": pe, + "parententity": pe, + "parentmapper": pe, + "orm_key": self.prop.key, + "compile_state_plugin": "orm", + } + ) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": pe} + ) ) def _memoized_attr_info(self): @@ -388,6 +394,11 @@ class ColumnProperty(StrategizedProperty): "orm_key": self.prop.key, "compile_state_plugin": "orm", } + )._set_propagate_attrs( + { + "compile_state_plugin": "orm", + "plugin_subject": self._parententity, + } ) for col in self.prop.columns ] diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8a861c3dc3..25d6f47361 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -18,6 +18,7 @@ ORM session, whereas the ``Select`` construct interacts directly with the database to return iterable result sets. """ +import itertools from . import attributes from . import exc as orm_exc @@ -28,7 +29,8 @@ from .base import _assertions from .context import _column_descriptions from .context import _legacy_determine_last_joined_entity from .context import _legacy_filter_by_entity_zero -from .context import QueryCompileState +from .context import ORMCompileState +from .context import ORMFromStatementCompileState from .context import QueryContext from .interfaces import ORMColumnsClauseRole from .util import aliased @@ -42,18 +44,22 @@ from .. import inspection from .. import log from .. import sql from .. import util +from ..future.selectable import Select as FutureSelect from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util +from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _generative from ..sql.base import Executable +from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectStatementGrouping from ..sql.util import _entity_namespace_key from ..util import collections_abc @@ -62,7 +68,15 @@ __all__ = ["Query", "QueryContext", "aliased"] @inspection._self_inspects @log.class_logger -class Query(HasPrefixes, HasSuffixes, HasHints, Executable): +class Query( + _SelectFromElements, + SupportsCloneAnnotations, + HasPrefixes, + HasSuffixes, + HasHints, + Executable, +): + """ORM-level SQL construction object. :class:`_query.Query` @@ -105,7 +119,7 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): _legacy_setup_joins = () _label_style = LABEL_STYLE_NONE - compile_options = QueryCompileState.default_compile_options + compile_options = ORMCompileState.default_compile_options load_options = QueryContext.default_load_options @@ -115,6 +129,11 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): _enable_assertions = True _last_joined_entity = None + # mirrors that of ClauseElement, used to propagate the "orm" + # plugin as well as the "subject" of the plugin, e.g. the mapper + # we are querying against. + _propagate_attrs = util.immutabledict() + def __init__(self, entities, session=None): """Construct a :class:`_query.Query` directly. @@ -148,7 +167,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): def _set_entities(self, entities): self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, ent) + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) for ent in util.to_list(entities) ] @@ -183,7 +204,10 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): def _set_select_from(self, obj, set_base_alias): fa = [ coercions.expect( - roles.StrictFromClauseRole, elem, allow_select=True + roles.StrictFromClauseRole, + elem, + allow_select=True, + apply_propagate_attrs=self, ) for elem in obj ] @@ -332,15 +356,13 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): if ( not self.compile_options._set_base_alias and not self.compile_options._with_polymorphic_adapt_map - and self.compile_options._statement is None + # and self.compile_options._statement is None ): # if we don't have legacy top level aliasing features in use # then convert to a future select() directly stmt = self._statement_20() else: - stmt = QueryCompileState._create_for_legacy_query( - self, for_statement=True - ).statement + stmt = self._compile_state(for_statement=True).statement if self.load_options._params: # this is the search and replace thing. this is kind of nuts @@ -349,8 +371,67 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): return stmt - def _statement_20(self): - return QueryCompileState._create_future_select_from_query(self) + def _statement_20(self, orm_results=False): + # TODO: this event needs to be deprecated, as it currently applies + # only to ORM query and occurs at this spot that is now more + # or less an artificial spot + if self.dispatch.before_compile: + for fn in self.dispatch.before_compile: + new_query = fn(self) + if new_query is not None and new_query is not self: + self = new_query + if not fn._bake_ok: + self.compile_options += {"_bake_ok": False} + + if self.compile_options._statement is not None: + stmt = FromStatement( + self._raw_columns, self.compile_options._statement + ) + # TODO: once SubqueryLoader uses select(), we can remove + # "_orm_query" from this structure + stmt.__dict__.update( + _with_options=self._with_options, + _with_context_options=self._with_context_options, + compile_options=self.compile_options + + {"_orm_query": self.with_session(None)}, + _execution_options=self._execution_options, + ) + stmt._propagate_attrs = self._propagate_attrs + else: + stmt = FutureSelect.__new__(FutureSelect) + + stmt.__dict__.update( + _raw_columns=self._raw_columns, + _where_criteria=self._where_criteria, + _from_obj=self._from_obj, + _legacy_setup_joins=self._legacy_setup_joins, + _order_by_clauses=self._order_by_clauses, + _group_by_clauses=self._group_by_clauses, + _having_criteria=self._having_criteria, + _distinct=self._distinct, + _distinct_on=self._distinct_on, + _with_options=self._with_options, + _with_context_options=self._with_context_options, + _hints=self._hints, + _statement_hints=self._statement_hints, + _correlate=self._correlate, + _auto_correlate=self._auto_correlate, + _limit_clause=self._limit_clause, + _offset_clause=self._offset_clause, + _for_update_arg=self._for_update_arg, + _prefixes=self._prefixes, + _suffixes=self._suffixes, + _label_style=self._label_style, + compile_options=self.compile_options + + {"_orm_query": self.with_session(None)}, + _execution_options=self._execution_options, + ) + + if not orm_results: + stmt.compile_options += {"_orm_results": False} + + stmt._propagate_attrs = self._propagate_attrs + return stmt def subquery(self, name=None, with_labels=False, reduce_columns=False): """return the full SELECT statement represented by @@ -879,7 +960,17 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): elif instance is attributes.PASSIVE_CLASS_MISMATCH: return None - return db_load_fn(self, primary_key_identity) + # apply_labels() not strictly necessary, however this will ensure that + # tablename_colname style is used which at the moment is asserted + # in a lot of unit tests :) + + statement = self._statement_20(orm_results=True).apply_labels() + return db_load_fn( + self.session, + statement, + primary_key_identity, + load_options=self.load_options, + ) @property def lazy_loaded_from(self): @@ -1059,7 +1150,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): self._raw_columns = list(self._raw_columns) self._raw_columns.append( - coercions.expect(roles.ColumnsClauseRole, entity) + coercions.expect( + roles.ColumnsClauseRole, entity, apply_propagate_attrs=self + ) ) @_generative @@ -1397,7 +1490,10 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): self._raw_columns = list(self._raw_columns) self._raw_columns.extend( - coercions.expect(roles.ColumnsClauseRole, c) for c in column + coercions.expect( + roles.ColumnsClauseRole, c, apply_propagate_attrs=self + ) + for c in column ) @util.deprecated( @@ -1584,7 +1680,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): """ for criterion in list(criterion): - criterion = coercions.expect(roles.WhereHavingRole, criterion) + criterion = coercions.expect( + roles.WhereHavingRole, criterion, apply_propagate_attrs=self + ) # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv if self._aliased_generation: @@ -1742,7 +1840,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): """ self._having_criteria += ( - coercions.expect(roles.WhereHavingRole, criterion), + coercions.expect( + roles.WhereHavingRole, criterion, apply_propagate_attrs=self + ), ) def _set_op(self, expr_fn, *q): @@ -2177,7 +2277,12 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): self._legacy_setup_joins += tuple( ( - coercions.expect(roles.JoinTargetRole, prop[0], legacy=True), + coercions.expect( + roles.JoinTargetRole, + prop[0], + legacy=True, + apply_propagate_attrs=self, + ), prop[1] if len(prop) == 2 else None, None, { @@ -2605,7 +2710,9 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): ORM tutorial """ - statement = coercions.expect(roles.SelectStatementRole, statement) + statement = coercions.expect( + roles.SelectStatementRole, statement, apply_propagate_attrs=self + ) self.compile_options += {"_statement": statement} def first(self): @@ -2711,76 +2818,50 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): def __iter__(self): return self._iter().__iter__() - # TODO: having _iter(), _execute_and_instances, _connection_from_session, - # etc., is all too much. + def _iter(self): + # new style execution. + params = self.load_options._params + statement = self._statement_20(orm_results=True) + result = self.session.execute( + statement, + params, + execution_options={"_sa_orm_load_options": self.load_options}, + ) - # new recipes / extensions should be based on an event hook of some kind, - # can allow an execution that would return a Result to take in all the - # information and return a different Result. this has to be at - # the session / connection .execute() level, and can perhaps be - # before_execute() but needs to be focused around rewriting of results. + # legacy: automatically set scalars, unique + if result._attributes.get("is_single_entity", False): + result = result.scalars() - # the dialect do_execute() *may* be this but that seems a bit too low - # level. it may need to be ORM session based and be a session event, - # becasue it might not invoke the cursor, might invoke for multiple - # connections, etc. OK really has to be a session level event in this - # case to support horizontal sharding. + if result._attributes.get("filtered", False): + result = result.unique() - def _iter(self): - context = self._compile_context() + return result + + def _execute_crud(self, stmt, mapper): + conn = self.session.connection( + mapper=mapper, clause=stmt, close_with_result=True + ) - if self.load_options._autoflush: - self.session._autoflush() - return self._execute_and_instances(context) + return conn._execute_20( + stmt, self.load_options._params, self._execution_options + ) def __str__(self): - compile_state = self._compile_state() + statement = self._statement_20(orm_results=True) + try: bind = ( - self._get_bind_args(compile_state, self.session.get_bind) + self._get_bind_args(statement, self.session.get_bind) if self.session else None ) except sa_exc.UnboundExecutionError: bind = None - return str(compile_state.statement.compile(bind)) - - def _connection_from_session(self, **kw): - conn = self.session.connection(**kw) - if self._execution_options: - conn = conn.execution_options(**self._execution_options) - return conn - - def _execute_and_instances(self, querycontext, params=None): - conn = self._get_bind_args( - querycontext.compile_state, - self._connection_from_session, - close_with_result=True, - ) - if params is None: - params = querycontext.load_options._params + return str(statement.compile(bind)) - result = conn._execute_20( - querycontext.compile_state.statement, - params, - # execution_options=self.session._orm_execution_options(), - ) - return loading.instances(querycontext.query, result, querycontext) - - def _execute_crud(self, stmt, mapper): - conn = self._connection_from_session( - mapper=mapper, clause=stmt, close_with_result=True - ) - - return conn.execute(stmt, self.load_options._params) - - def _get_bind_args(self, compile_state, fn, **kw): - return fn( - mapper=compile_state._bind_mapper(), - clause=compile_state.statement, - **kw - ) + def _get_bind_args(self, statement, fn, **kw): + return fn(clause=statement, **kw) @property def column_descriptions(self): @@ -2837,10 +2918,21 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): "for linking ORM results to arbitrary select constructs.", version="1.4", ) - compile_state = QueryCompileState._create_for_legacy_query(self) - context = QueryContext(compile_state, self.session) + compile_state = ORMCompileState._create_for_legacy_query(self) + context = QueryContext( + compile_state, self.session, self.load_options + ) + + result = loading.instances(result_proxy, context) + + # legacy: automatically set scalars, unique + if result._attributes.get("is_single_entity", False): + result = result.scalars() + + if result._attributes.get("filtered", False): + result = result.unique() - return loading.instances(self, result_proxy, context) + return result def merge_result(self, iterator, load=True): """Merge a result into this :class:`_query.Query` object's Session. @@ -3239,36 +3331,62 @@ class Query(HasPrefixes, HasSuffixes, HasHints, Executable): return update_op.rowcount def _compile_state(self, for_statement=False, **kw): - # TODO: this needs to become a general event for all - # Executable objects as well (all ClauseElement?) - # but then how do we clarify that this event is only for - # *top level* compile, not as an embedded element is visted? - # how does that even work because right now a Query that does things - # like from_self() will in fact invoke before_compile for each - # inner element. - # OK perhaps with 2.0 style folks will continue using before_execute() - # as they can now, as a select() with ORM elements will be delivered - # there, OK. sort of fixes the "bake_ok" problem too. - if self.dispatch.before_compile: - for fn in self.dispatch.before_compile: - new_query = fn(self) - if new_query is not None and new_query is not self: - self = new_query - if not fn._bake_ok: - self.compile_options += {"_bake_ok": False} - - compile_state = QueryCompileState._create_for_legacy_query( + return ORMCompileState._create_for_legacy_query( self, for_statement=for_statement, **kw ) - return compile_state def _compile_context(self, for_statement=False): compile_state = self._compile_state(for_statement=for_statement) - context = QueryContext(compile_state, self.session) + context = QueryContext(compile_state, self.session, self.load_options) return context +class FromStatement(SelectStatementGrouping, Executable): + """Core construct that represents a load of ORM objects from a finished + select or text construct. + + """ + + compile_options = ORMFromStatementCompileState.default_compile_options + + _compile_state_factory = ORMFromStatementCompileState.create_for_statement + + _is_future = True + + _for_update_arg = None + + def __init__(self, entities, element): + self._raw_columns = [ + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) + for ent in util.to_list(entities) + ] + super(FromStatement, self).__init__(element) + + def _compiler_dispatch(self, compiler, **kw): + compile_state = self._compile_state_factory(self, self, **kw) + + toplevel = not compiler.stack + + if toplevel: + compiler.compile_state = compile_state + + return compiler.process(compile_state.statement, **kw) + + def _ensure_disambiguated_names(self): + return self + + def get_children(self, **kw): + for elem in itertools.chain.from_iterable( + element._from_objects for element in self._raw_columns + ): + yield elem + for elem in super(FromStatement, self).get_children(**kw): + yield elem + + class AliasOption(interfaces.LoaderOption): @util.deprecated( "1.4", diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index f539e968fa..e82cd174fc 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2737,12 +2737,12 @@ class JoinCondition(object): def replace(element): if "remote" in element._annotations: - v = element._annotations.copy() + v = dict(element._annotations) del v["remote"] v["local"] = True return element._with_annotations(v) elif "local" in element._annotations: - v = element._annotations.copy() + v = dict(element._annotations) del v["local"] v["remote"] = True return element._with_annotations(v) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 6cb8a0062b..8d2f13df3d 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -12,6 +12,7 @@ import sys import weakref from . import attributes +from . import context from . import exc from . import identity from . import loading @@ -28,13 +29,12 @@ from .base import state_str from .unitofwork import UOWTransaction from .. import engine from .. import exc as sa_exc -from .. import sql +from .. import future from .. import util from ..inspection import inspect from ..sql import coercions from ..sql import roles -from ..sql import util as sql_util - +from ..sql import visitors __all__ = ["Session", "SessionTransaction", "sessionmaker"] @@ -98,6 +98,160 @@ DEACTIVE = util.symbol("DEACTIVE") CLOSED = util.symbol("CLOSED") +class ORMExecuteState(object): + """Stateful object used for the :meth:`.SessionEvents.do_orm_execute` + + .. versionadded:: 1.4 + + """ + + __slots__ = ( + "session", + "statement", + "parameters", + "execution_options", + "bind_arguments", + ) + + def __init__( + self, session, statement, parameters, execution_options, bind_arguments + ): + self.session = session + self.statement = statement + self.parameters = parameters + self.execution_options = execution_options + self.bind_arguments = bind_arguments + + def invoke_statement( + self, + statement=None, + params=None, + execution_options=None, + bind_arguments=None, + ): + """Execute the statement represented by this + :class:`.ORMExecuteState`, without re-invoking events. + + This method essentially performs a re-entrant execution of the + current statement for which the :meth:`.SessionEvents.do_orm_execute` + event is being currently invoked. The use case for this is + for event handlers that want to override how the ultimate results + object is returned, such as for schemes that retrieve results from + an offline cache or which concatenate results from multiple executions. + + :param statement: optional statement to be invoked, in place of the + statement currently represented by :attr:`.ORMExecuteState.statement`. + + :param params: optional dictionary of parameters which will be merged + into the existing :attr:`.ORMExecuteState.parameters` of this + :class:`.ORMExecuteState`. + + :param execution_options: optional dictionary of execution options + will be merged into the existing + :attr:`.ORMExecuteState.execution_options` of this + :class:`.ORMExecuteState`. + + :param bind_arguments: optional dictionary of bind_arguments + which will be merged amongst the current + :attr:`.ORMExecuteState.bind_arguments` + of this :class:`.ORMExecuteState`. + + :return: a :class:`_engine.Result` object with ORM-level results. + + .. seealso:: + + :ref:`examples_caching` - includes example use of the + :meth:`.SessionEvents.do_orm_execute` hook as well as the + :meth:`.ORMExecuteState.invoke_query` method. + + + """ + + if statement is None: + statement = self.statement + + _bind_arguments = dict(self.bind_arguments) + if bind_arguments: + _bind_arguments.update(bind_arguments) + _bind_arguments["_sa_skip_events"] = True + + if params: + _params = dict(self.parameters) + _params.update(params) + else: + _params = self.parameters + + if execution_options: + _execution_options = dict(self.execution_options) + _execution_options.update(execution_options) + else: + _execution_options = self.execution_options + + return self.session.execute( + statement, _params, _execution_options, _bind_arguments + ) + + @property + def orm_query(self): + """Return the :class:`_orm.Query` object associated with this + execution. + + For SQLAlchemy-2.0 style usage, the :class:`_orm.Query` object + is not used at all, and this attribute will return None. + + """ + load_opts = self.load_options + if load_opts._orm_query: + return load_opts._orm_query + + opts = self._orm_compile_options() + if opts is not None: + return opts._orm_query + else: + return None + + def _orm_compile_options(self): + opts = self.statement.compile_options + if isinstance(opts, context.ORMCompileState.default_compile_options): + return opts + else: + return None + + @property + def loader_strategy_path(self): + """Return the :class:`.PathRegistry` for the current load path. + + This object represents the "path" in a query along relationships + when a particular object or collection is being loaded. + + """ + opts = self._orm_compile_options() + if opts is not None: + return opts._current_path + else: + return None + + @property + def load_options(self): + """Return the load_options that will be used for this execution.""" + + return self.execution_options.get( + "_sa_orm_load_options", context.QueryContext.default_load_options + ) + + @property + def user_defined_options(self): + """The sequence of :class:`.UserDefinedOptions` that have been + associated with the statement being invoked. + + """ + return [ + opt + for opt in self.statement._with_options + if not opt._is_compile_state and not opt._is_legacy_option + ] + + class SessionTransaction(object): """A :class:`.Session`-level transaction. @@ -1032,9 +1186,7 @@ class Session(_SessionClassMethods): def connection( self, - mapper=None, - clause=None, - bind=None, + bind_arguments=None, close_with_result=False, execution_options=None, **kw @@ -1059,23 +1211,18 @@ class Session(_SessionClassMethods): resolved through any of the optional keyword arguments. This ultimately makes usage of the :meth:`.get_bind` method for resolution. + :param bind_arguments: dictionary of bind arguments. may include + "mapper", "bind", "clause", other custom arguments that are passed + to :meth:`.Session.get_bind`. + :param bind: - Optional :class:`_engine.Engine` to be used as the bind. If - this engine is already involved in an ongoing transaction, - that connection will be used. This argument takes precedence - over ``mapper``, ``clause``. + deprecated; use bind_arguments :param mapper: - Optional :func:`.mapper` mapped class, used to identify - the appropriate bind. This argument takes precedence over - ``clause``. + deprecated; use bind_arguments :param clause: - A :class:`_expression.ClauseElement` (i.e. - :func:`_expression.select`, - :func:`_expression.text`, - etc.) which will be used to locate a bind, if a bind - cannot otherwise be identified. + deprecated; use bind_arguments :param close_with_result: Passed to :meth:`_engine.Engine.connect`, indicating the :class:`_engine.Connection` should be considered @@ -1097,13 +1244,16 @@ class Session(_SessionClassMethods): :ref:`session_transaction_isolation` :param \**kw: - Additional keyword arguments are sent to :meth:`get_bind()`, - allowing additional arguments to be passed to custom - implementations of :meth:`get_bind`. + deprecated; use bind_arguments """ + + if not bind_arguments: + bind_arguments = kw + + bind = bind_arguments.pop("bind", None) if bind is None: - bind = self.get_bind(mapper, clause=clause, **kw) + bind = self.get_bind(**bind_arguments) return self._connection_for_bind( bind, @@ -1124,7 +1274,14 @@ class Session(_SessionClassMethods): conn = conn.execution_options(**execution_options) return conn - def execute(self, clause, params=None, mapper=None, bind=None, **kw): + def execute( + self, + statement, + params=None, + execution_options=util.immutabledict(), + bind_arguments=None, + **kw + ): r"""Execute a SQL expression construct or string statement within the current transaction. @@ -1222,22 +1379,19 @@ class Session(_SessionClassMethods): "executemany" will be invoked. The keys in each dictionary must correspond to parameter names present in the statement. + :param bind_arguments: dictionary of additional arguments to determine + the bind. may include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + :param mapper: - Optional :func:`.mapper` or mapped class, used to identify - the appropriate bind. This argument takes precedence over - ``clause`` when locating a bind. See :meth:`.Session.get_bind` - for more details. + deprecated; use the bind_arguments dictionary :param bind: - Optional :class:`_engine.Engine` to be used as the bind. If - this engine is already involved in an ongoing transaction, - that connection will be used. This argument takes - precedence over ``mapper`` and ``clause`` when locating - a bind. + deprecated; use the bind_arguments dictionary :param \**kw: - Additional keyword arguments are sent to :meth:`.Session.get_bind()` - to allow extensibility of "bind" schemes. + deprecated; use the bind_arguments dictionary .. seealso:: @@ -1253,20 +1407,63 @@ class Session(_SessionClassMethods): in order to execute the statement. """ - clause = coercions.expect(roles.CoerceTextStatementRole, clause) - if bind is None: - bind = self.get_bind(mapper, clause=clause, **kw) + statement = coercions.expect(roles.CoerceTextStatementRole, statement) - return self._connection_for_bind( - bind, close_with_result=True - )._execute_20(clause, params,) + if not bind_arguments: + bind_arguments = kw + elif kw: + bind_arguments.update(kw) + + compile_state_cls = statement._get_plugin_compile_state_cls("orm") + if compile_state_cls: + compile_state_cls.orm_pre_session_exec( + self, statement, execution_options, bind_arguments + ) + else: + bind_arguments.setdefault("clause", statement) + if statement._is_future: + execution_options = util.immutabledict().merge_with( + execution_options, {"future_result": True} + ) + + if self.dispatch.do_orm_execute: + skip_events = bind_arguments.pop("_sa_skip_events", False) + + if not skip_events: + orm_exec_state = ORMExecuteState( + self, statement, params, execution_options, bind_arguments + ) + for fn in self.dispatch.do_orm_execute: + result = fn(orm_exec_state) + if result: + return result + + bind = self.get_bind(**bind_arguments) + + conn = self._connection_for_bind(bind, close_with_result=True) + result = conn._execute_20(statement, params or {}, execution_options) - def scalar(self, clause, params=None, mapper=None, bind=None, **kw): + if compile_state_cls: + result = compile_state_cls.orm_setup_cursor_result( + self, bind_arguments, result + ) + + return result + + def scalar( + self, + statement, + params=None, + execution_options=None, + mapper=None, + bind=None, + **kw + ): """Like :meth:`~.Session.execute` but return a scalar result.""" return self.execute( - clause, params=params, mapper=mapper, bind=bind, **kw + statement, params=params, mapper=mapper, bind=bind, **kw ).scalar() def close(self): @@ -1422,7 +1619,7 @@ class Session(_SessionClassMethods): """ self._add_bind(table, bind) - def get_bind(self, mapper=None, clause=None): + def get_bind(self, mapper=None, clause=None, bind=None): """Return a "bind" to which this :class:`.Session` is bound. The "bind" is usually an instance of :class:`_engine.Engine`, @@ -1497,6 +1694,8 @@ class Session(_SessionClassMethods): :meth:`.Session.bind_table` """ + if bind: + return bind if mapper is clause is None: if self.bind: @@ -1520,6 +1719,8 @@ class Session(_SessionClassMethods): raise if self.__binds: + # matching mappers and selectables to entries in the + # binds dictionary; supported use case. if mapper: for cls in mapper.class_.__mro__: if cls in self.__binds: @@ -1528,18 +1729,32 @@ class Session(_SessionClassMethods): clause = mapper.persist_selectable if clause is not None: - for t in sql_util.find_tables(clause, include_crud=True): - if t in self.__binds: - return self.__binds[t] + for obj in visitors.iterate(clause): + if obj in self.__binds: + return self.__binds[obj] + # session has a single bind; supported use case. if self.bind: return self.bind - if isinstance(clause, sql.expression.ClauseElement) and clause.bind: - return clause.bind + # now we are in legacy territory. looking for "bind" on tables + # that are via bound metadata. this goes away in 2.0. + if mapper and clause is None: + clause = mapper.persist_selectable - if mapper and mapper.persist_selectable.bind: - return mapper.persist_selectable.bind + if clause is not None: + if clause.bind: + return clause.bind + # for obj in visitors.iterate(clause): + # if obj.bind: + # return obj.bind + + if mapper: + if mapper.persist_selectable.bind: + return mapper.persist_selectable.bind + # for obj in visitors.iterate(mapper.persist_selectable): + # if obj.bind: + # return obj.bind context = [] if mapper is not None: @@ -1722,9 +1937,11 @@ class Session(_SessionClassMethods): else: with_for_update = None + stmt = future.select(object_mapper(instance)) if ( loading.load_on_ident( - self.query(object_mapper(instance)), + self, + stmt, state.key, refresh_state=state, with_for_update=with_for_update, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index c0c090b3d1..a7d501b53f 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -33,6 +33,7 @@ from .util import _none_set from .util import aliased from .. import event from .. import exc as sa_exc +from .. import future from .. import inspect from .. import log from .. import sql @@ -440,10 +441,13 @@ class DeferredColumnLoader(LoaderStrategy): if self.raiseload: self._invoke_raise_load(state, passive, "raise") - query = session.query(localparent) if ( loading.load_on_ident( - query, state.key, only_load_props=group, refresh_state=state + session, + future.select(localparent).apply_labels(), + state.key, + only_load_props=group, + refresh_state=state, ) is None ): @@ -897,7 +901,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): q(session) .with_post_criteria(lambda q: q._set_lazyload_from(state)) ._load_on_pk_identity( - session.query(self.mapper), primary_key_identity + session, session.query(self.mapper), primary_key_identity ) ) @@ -1090,7 +1094,6 @@ class SubqueryLoader(PostLoader): parentmapper=None, **kwargs ): - if ( not compile_state.compile_options._enable_eagerloads or compile_state.compile_options._for_refresh_state @@ -1146,6 +1149,7 @@ class SubqueryLoader(PostLoader): # generate a new Query from the original, then # produce a subquery from it. left_alias = self._generate_from_original_query( + compile_state, orig_query, leftmost_mapper, leftmost_attr, @@ -1164,7 +1168,9 @@ class SubqueryLoader(PostLoader): def set_state_options(compile_state): compile_state.attributes.update( { - ("orig_query", SubqueryLoader): orig_query, + ("orig_query", SubqueryLoader): orig_query.with_session( + None + ), ("subquery_path", None): subq_path, } ) @@ -1188,6 +1194,7 @@ class SubqueryLoader(PostLoader): # by create_row_processor # NOTE: be sure to consult baked.py for some hardcoded logic # about this structure as well + assert q.session is None path.set( compile_state.attributes, "subqueryload_data", {"query": q}, ) @@ -1218,6 +1225,7 @@ class SubqueryLoader(PostLoader): def _generate_from_original_query( self, + orig_compile_state, orig_query, leftmost_mapper, leftmost_attr, @@ -1243,11 +1251,18 @@ class SubqueryLoader(PostLoader): } ) - cs = q._clone() + # NOTE: keystone has a test which is counting before_compile + # events. That test is in one case dependent on an extra + # call that was occurring here within the subqueryloader setup + # process, probably when the subquery() method was called. + # Ultimately that call will not be occurring here. + # the event has already been called on the original query when + # we are here in any case, so keystone will need to adjust that + # test. - # using the _compile_state method so that the before_compile() - # event is hit here. keystone is testing for this. - compile_state = cs._compile_state(entities_only=True) + # for column information, look to the compile state that is + # already being passed through + compile_state = orig_compile_state # select from the identity columns of the outer (specifically, these # are the 'local_cols' of the property). This will remove @@ -1260,7 +1275,6 @@ class SubqueryLoader(PostLoader): ], compile_state._get_current_adapter(), ) - # q.add_columns.non_generative(q, target_cols) q._set_entities(target_cols) distinct_target_key = leftmost_relationship.distinct_target_key @@ -1428,10 +1442,20 @@ class SubqueryLoader(PostLoader): """ - __slots__ = ("subq_info", "subq", "_data") + __slots__ = ( + "session", + "execution_options", + "load_options", + "subq", + "_data", + ) - def __init__(self, subq_info): - self.subq_info = subq_info + def __init__(self, context, subq_info): + # avoid creating a cycle by storing context + # even though that's preferable + self.session = context.session + self.execution_options = context.execution_options + self.load_options = context.load_options self.subq = subq_info["query"] self._data = None @@ -1443,7 +1467,17 @@ class SubqueryLoader(PostLoader): def _load(self): self._data = collections.defaultdict(list) - rows = list(self.subq) + q = self.subq + assert q.session is None + if "compiled_cache" in self.execution_options: + q = q.execution_options( + compiled_cache=self.execution_options["compiled_cache"] + ) + q = q.with_session(self.session) + + # to work with baked query, the parameters may have been + # updated since this query was created, so take these into account + rows = list(q.params(self.load_options._params)) for k, v in itertools.groupby(rows, lambda x: x[1:]): self._data[k].extend(vv[0] for vv in v) @@ -1474,14 +1508,7 @@ class SubqueryLoader(PostLoader): subq = subq_info["query"] - if subq.session is None: - subq.session = context.session - assert subq.session is context.session, ( - "Subquery session doesn't refer to that of " - "our context. Are there broken context caching " - "schemes being used?" - ) - + assert subq.session is None local_cols = self.parent_property.local_columns # cache the loaded collections in the context @@ -1489,7 +1516,7 @@ class SubqueryLoader(PostLoader): # call upon create_row_processor again collections = path.get(context.attributes, "collections") if collections is None: - collections = self._SubqCollections(subq_info) + collections = self._SubqCollections(context, subq_info) path.set(context.attributes, "collections", collections) if adapter: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 1e415e49c3..ce37d962e8 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -41,6 +41,7 @@ from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection @@ -694,6 +695,8 @@ class AliasedInsp( "entity_namespace": self, "compile_state_plugin": "orm", } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} ) @property @@ -748,10 +751,20 @@ class AliasedInsp( ) def _adapt_element(self, elem, key=None): - d = {"parententity": self, "parentmapper": self.mapper} + d = { + "parententity": self, + "parentmapper": self.mapper, + "compile_state_plugin": "orm", + } if key: d["orm_key"] = key - return self._adapter.traverse(elem)._annotate(d) + return ( + self._adapter.traverse(elem) + ._annotate(d) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + ) def _entity_for_mapper(self, mapper): self_poly = self.with_polymorphic_mappers @@ -1037,7 +1050,7 @@ def with_polymorphic( @inspection._self_inspects -class Bundle(ORMColumnsClauseRole, InspectionAttr): +class Bundle(ORMColumnsClauseRole, SupportsCloneAnnotations, InspectionAttr): """A grouping of SQL expressions that are returned by a :class:`.Query` under one namespace. @@ -1070,6 +1083,8 @@ class Bundle(ORMColumnsClauseRole, InspectionAttr): is_bundle = True + _propagate_attrs = util.immutabledict() + def __init__(self, name, *exprs, **kw): r"""Construct a new :class:`.Bundle`. @@ -1090,7 +1105,10 @@ class Bundle(ORMColumnsClauseRole, InspectionAttr): """ self.name = self._label = name self.exprs = exprs = [ - coercions.expect(roles.ColumnsClauseRole, expr) for expr in exprs + coercions.expect( + roles.ColumnsClauseRole, expr, apply_propagate_attrs=self + ) + for expr in exprs ] self.c = self.columns = ColumnCollection( @@ -1145,11 +1163,14 @@ class Bundle(ORMColumnsClauseRole, InspectionAttr): return cloned def __clause_element__(self): + annotations = self._annotations.union( + {"bundle": self, "entity_namespace": self} + ) return expression.ClauseList( _literal_as_text_role=roles.ColumnsClauseRole, group=False, *[e._annotations.get("bundle", e) for e in self.exprs] - )._annotate({"bundle": self, "entity_namespace": self}) + )._annotate(annotations) @property def clauses(self): diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 71d05f38f5..08ed121d30 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -17,8 +17,12 @@ from .traversals import anon_map from .visitors import InternalTraversal from .. import util +EMPTY_ANNOTATIONS = util.immutabledict() + class SupportsAnnotations(object): + _annotations = EMPTY_ANNOTATIONS + @util.memoized_property def _annotations_cache_key(self): anon_map_ = anon_map() @@ -40,7 +44,6 @@ class SupportsAnnotations(object): class SupportsCloneAnnotations(SupportsAnnotations): - _annotations = util.immutabledict() _clone_annotations_traverse_internals = [ ("_annotations", InternalTraversal.dp_annotations_key) @@ -113,12 +116,9 @@ class SupportsWrappingAnnotations(SupportsAnnotations): """ if clone: - # clone is used when we are also copying - # the expression for a deep deannotation - return self._clone() + s = self._clone() + return s else: - # if no clone, since we have no annotations we return - # self return self @@ -163,12 +163,11 @@ class Annotated(object): self.__dict__.pop("_annotations_cache_key", None) self.__dict__.pop("_generate_cache_key", None) self.__element = element - self._annotations = values + self._annotations = util.immutabledict(values) self._hash = hash(element) def _annotate(self, values): - _values = self._annotations.copy() - _values.update(values) + _values = self._annotations.union(values) return self._with_annotations(_values) def _with_annotations(self, values): @@ -183,10 +182,15 @@ class Annotated(object): if values is None: return self.__element else: - _values = self._annotations.copy() - for v in values: - _values.pop(v, None) - return self._with_annotations(_values) + return self._with_annotations( + util.immutabledict( + { + key: value + for key, value in self._annotations.items() + if key not in values + } + ) + ) def _compiler_dispatch(self, visitor, **kw): return self.__element.__class__._compiler_dispatch(self, visitor, **kw) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 04cc344808..bb606a4d6e 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -439,46 +439,53 @@ class CompileState(object): plugins = {} @classmethod - def _create(cls, statement, compiler, **kw): + def create_for_statement(cls, statement, compiler, **kw): # factory construction. - if statement._compile_state_plugin is not None: - constructor = cls.plugins.get( - ( - statement._compile_state_plugin, - statement.__visit_name__, - None, - ), - cls, + if statement._propagate_attrs: + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" ) else: - constructor = cls + plugin_name = "default" + + klass = cls.plugins[(plugin_name, statement.__visit_name__)] - return constructor(statement, compiler, **kw) + if klass is cls: + return cls(statement, compiler, **kw) + else: + return klass.create_for_statement(statement, compiler, **kw) def __init__(self, statement, compiler, **kw): self.statement = statement @classmethod - def get_plugin_classmethod(cls, statement, name): - if statement._compile_state_plugin is not None: - fn = cls.plugins.get( - ( - statement._compile_state_plugin, - statement.__visit_name__, - name, - ), - None, - ) - if fn is not None: - return fn - return getattr(cls, name) + def get_plugin_class(cls, statement): + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + try: + return cls.plugins[(plugin_name, statement.__visit_name__)] + except KeyError: + return None @classmethod - def plugin_for(cls, plugin_name, visit_name, method_name=None): - def decorate(fn): - cls.plugins[(plugin_name, visit_name, method_name)] = fn - return fn + def _get_plugin_compile_state_cls(cls, statement, plugin_name): + statement_plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + if statement_plugin_name != plugin_name: + return None + try: + return cls.plugins[(plugin_name, statement.__visit_name__)] + except KeyError: + return None + + @classmethod + def plugin_for(cls, plugin_name, visit_name): + def decorate(cls_to_decorate): + cls.plugins[(plugin_name, visit_name)] = cls_to_decorate + return cls_to_decorate return decorate @@ -508,12 +515,12 @@ class InPlaceGenerative(HasMemoized): class HasCompileState(Generative): """A class that has a :class:`.CompileState` associated with it.""" - _compile_state_factory = CompileState._create - _compile_state_plugin = None _attributes = util.immutabledict() + _compile_state_factory = CompileState.create_for_statement + class _MetaOptions(type): """metaclass for the Options class.""" @@ -549,6 +556,16 @@ class Options(util.with_metaclass(_MetaOptions)): def add_to_element(self, name, value): return self + {name: getattr(self, name) + value} + @hybridmethod + def _state_dict(self): + return self.__dict__ + + _state_dict_const = util.immutabledict() + + @_state_dict.classlevel + def _state_dict(cls): + return cls._state_dict_const + class CacheableOptions(Options, HasCacheKey): @hybridmethod @@ -590,6 +607,9 @@ class Executable(Generative): def _disable_caching(self): self._cache_enable = HasCacheKey() + def _get_plugin_compile_state_cls(self, plugin_name): + return CompileState._get_plugin_compile_state_cls(self, plugin_name) + @_generative def options(self, *options): """Apply options to this statement. diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 2fc63b82f2..d8ef0222a8 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -50,19 +50,26 @@ def _document_text_coercion(paramname, meth_rst, param_rst): ) -def expect(role, element, **kw): +def expect(role, element, apply_propagate_attrs=None, **kw): # major case is that we are given a ClauseElement already, skip more # elaborate logic up front if possible impl = _impl_lookup[role] if not isinstance( element, - (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue,), + (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue), ): resolved = impl._resolve_for_clause_element(element, **kw) else: resolved = element + if ( + apply_propagate_attrs is not None + and not apply_propagate_attrs._propagate_attrs + and resolved._propagate_attrs + ): + apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs + if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: resolved = impl._post_coercion(resolved, **kw) @@ -106,32 +113,32 @@ class RoleImpl(object): self.name = role_class._role_name self._use_inspection = issubclass(role_class, roles.UsesInspection) - def _resolve_for_clause_element( - self, element, argname=None, apply_plugins=None, **kw - ): + def _resolve_for_clause_element(self, element, argname=None, **kw): original_element = element is_clause_element = False + while hasattr(element, "__clause_element__"): is_clause_element = True if not getattr(element, "is_clause_element", False): element = element.__clause_element__() else: - break - - should_apply_plugins = ( - apply_plugins is not None - and apply_plugins._compile_state_plugin is None - ) + return element + + if not is_clause_element: + if self._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + insp._post_inspect + try: + element = insp.__clause_element__() + except AttributeError: + self._raise_for_expected(original_element, argname) + else: + return element - if is_clause_element: - if ( - should_apply_plugins - and "compile_state_plugin" in element._annotations - ): - apply_plugins._compile_state_plugin = element._annotations[ - "compile_state_plugin" - ] + return self._literal_coercion(element, argname=argname, **kw) + else: return element if self._use_inspection: @@ -142,14 +149,6 @@ class RoleImpl(object): element = insp.__clause_element__() except AttributeError: self._raise_for_expected(original_element, argname) - else: - if ( - should_apply_plugins - and "compile_state_plugin" in element._annotations - ): - plugin = element._annotations["compile_state_plugin"] - apply_plugins._compile_state_plugin = plugin - return element return self._literal_coercion(element, argname=argname, **kw) @@ -649,8 +648,8 @@ class SelectStatementImpl(_NoTextCoercion, RoleImpl): self._raise_for_expected(original_element, argname, resolved) -class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole): - pass +class HasCTEImpl(ReturnsRowsImpl): + __slots__ = () class JoinTargetImpl(RoleImpl): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9a7646743a..8eae0ab7d5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ import contextlib import itertools import operator import re +import time from . import base from . import coercions @@ -380,6 +381,54 @@ class Compiled(object): sub-elements of the statement can modify these. """ + compile_state = None + """Optional :class:`.CompileState` object that maintains additional + state used by the compiler. + + Major executable objects such as :class:`_expression.Insert`, + :class:`_expression.Update`, :class:`_expression.Delete`, + :class:`_expression.Select` will generate this + state when compiled in order to calculate additional information about the + object. For the top level object that is to be executed, the state can be + stored here where it can also have applicability towards result set + processing. + + .. versionadded:: 1.4 + + """ + + _rewrites_selected_columns = False + """if True, indicates the compile_state object rewrites an incoming + ReturnsRows (like a Select) so that the columns we compile against in the + result set are not what were expressed on the outside. this is a hint to + the execution context to not link the statement.selected_columns to the + columns mapped in the result object. + + That is, when this flag is False:: + + stmt = some_statement() + + result = conn.execute(stmt) + row = result.first() + + # selected_columns are in a 1-1 relationship with the + # columns in the result, and are targetable in mapping + for col in stmt.selected_columns: + assert col in row._mapping + + When True:: + + # selected columns are not what are in the rows. the context + # rewrote the statement for some other set of selected_columns. + for col in stmt.selected_columns: + assert col not in row._mapping + + + """ + + cache_key = None + _gen_time = None + def __init__( self, dialect, @@ -433,6 +482,7 @@ class Compiled(object): self.string = self.preparer._render_schema_translates( self.string, schema_translate_map ) + self._gen_time = time.time() def _execute_on_connection( self, connection, multiparams, params, execution_options @@ -637,28 +687,6 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () - compile_state = None - """Optional :class:`.CompileState` object that maintains additional - state used by the compiler. - - Major executable objects such as :class:`_expression.Insert`, - :class:`_expression.Update`, :class:`_expression.Delete`, - :class:`_expression.Select` will generate this - state when compiled in order to calculate additional information about the - object. For the top level object that is to be executed, the state can be - stored here where it can also have applicability towards result set - processing. - - .. versionadded:: 1.4 - - """ - - compile_state_factories = util.immutabledict() - """Dictionary of alternate :class:`.CompileState` factories for given - classes, identified by their visit_name. - - """ - def __init__( self, dialect, @@ -667,7 +695,6 @@ class SQLCompiler(Compiled): column_keys=None, inline=False, linting=NO_LINTING, - compile_state_factories=None, **kwargs ): """Construct a new :class:`.SQLCompiler` object. @@ -734,9 +761,6 @@ class SQLCompiler(Compiled): # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} - if compile_state_factories: - self.compile_state_factories = compile_state_factories - Compiled.__init__(self, dialect, statement, **kwargs) if ( @@ -1542,7 +1566,7 @@ class SQLCompiler(Compiled): compile_state = cs._compile_state_factory(cs, self, **kwargs) - if toplevel: + if toplevel and not self.compile_state: self.compile_state = compile_state entry = self._default_stack_entry if toplevel else self.stack[-1] @@ -2541,6 +2565,13 @@ class SQLCompiler(Compiled): ) return froms + translate_select_structure = None + """if none None, should be a callable which accepts (select_stmt, **kw) + and returns a select object. this is used for structural changes + mostly to accommodate for LIMIT/OFFSET schemes + + """ + def visit_select( self, select_stmt, @@ -2552,7 +2583,17 @@ class SQLCompiler(Compiled): from_linter=None, **kwargs ): + assert select_wraps_for is None, ( + "SQLAlchemy 1.4 requires use of " + "the translate_select_structure hook for structural " + "translations of SELECT objects" + ) + # initial setup of SELECT. the compile_state_factory may now + # be creating a totally different SELECT from the one that was + # passed in. for ORM use this will convert from an ORM-state + # SELECT to a regular "Core" SELECT. other composed operations + # such as computation of joins will be performed. compile_state = select_stmt._compile_state_factory( select_stmt, self, **kwargs ) @@ -2560,9 +2601,29 @@ class SQLCompiler(Compiled): toplevel = not self.stack - if toplevel: + if toplevel and not self.compile_state: self.compile_state = compile_state + # translate step for Oracle, SQL Server which often need to + # restructure the SELECT to allow for LIMIT/OFFSET and possibly + # other conditions + if self.translate_select_structure: + new_select_stmt = self.translate_select_structure( + select_stmt, asfrom=asfrom, **kwargs + ) + + # if SELECT was restructured, maintain a link to the originals + # and assemble a new compile state + if new_select_stmt is not select_stmt: + compile_state_wraps_for = compile_state + select_wraps_for = select_stmt + select_stmt = new_select_stmt + + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + entry = self._default_stack_entry if toplevel else self.stack[-1] populate_result_map = need_column_expressions = ( @@ -2624,13 +2685,9 @@ class SQLCompiler(Compiled): ] if populate_result_map and select_wraps_for is not None: - # if this select is a compiler-generated wrapper, + # if this select was generated from translate_select, # rewrite the targeted columns in the result map - compile_state_wraps_for = select_wraps_for._compile_state_factory( - select_wraps_for, self, **kwargs - ) - translate = dict( zip( [ @@ -3013,7 +3070,8 @@ class SQLCompiler(Compiled): if toplevel: self.isinsert = True - self.compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state self.stack.append( { diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 3dc4e917cf..467a764d62 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -39,54 +39,8 @@ class DMLState(CompileState): isdelete = False isinsert = False - @classmethod - def _create_insert(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isinsert=True, **kw) - - @classmethod - def _create_update(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isupdate=True, **kw) - - @classmethod - def _create_delete(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isdelete=True, **kw) - - def __init__( - self, - statement, - compiler, - isinsert=False, - isupdate=False, - isdelete=False, - **kw - ): - self.statement = statement - - if isupdate: - self.isupdate = True - self._preserve_parameter_order = ( - statement._preserve_parameter_order - ) - if statement._ordered_values is not None: - self._process_ordered_values(statement) - elif statement._values is not None: - self._process_values(statement) - elif statement._multi_values: - self._process_multi_values(statement) - self._extra_froms = self._make_extra_froms(statement) - elif isinsert: - self.isinsert = True - if statement._select_names: - self._process_select_values(statement) - if statement._values is not None: - self._process_values(statement) - if statement._multi_values: - self._process_multi_values(statement) - elif isdelete: - self.isdelete = True - self._extra_froms = self._make_extra_froms(statement) - else: - assert False, "one of isinsert, isupdate, or isdelete must be set" + def __init__(self, statement, compiler, **kw): + raise NotImplementedError() def _make_extra_froms(self, statement): froms = [] @@ -174,6 +128,51 @@ class DMLState(CompileState): ) +@CompileState.plugin_for("default", "insert") +class InsertDMLState(DMLState): + isinsert = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isinsert = True + if statement._select_names: + self._process_select_values(statement) + if statement._values is not None: + self._process_values(statement) + if statement._multi_values: + self._process_multi_values(statement) + + +@CompileState.plugin_for("default", "update") +class UpdateDMLState(DMLState): + isupdate = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isupdate = True + self._preserve_parameter_order = statement._preserve_parameter_order + if statement._ordered_values is not None: + self._process_ordered_values(statement) + elif statement._values is not None: + self._process_values(statement) + elif statement._multi_values: + self._process_multi_values(statement) + self._extra_froms = self._make_extra_froms(statement) + + +@CompileState.plugin_for("default", "delete") +class DeleteDMLState(DMLState): + isdelete = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isdelete = True + self._extra_froms = self._make_extra_froms(statement) + + class UpdateBase( roles.DMLRole, HasCTE, @@ -754,8 +753,6 @@ class Insert(ValuesBase): _supports_multi_parameters = True - _compile_state_factory = DMLState._create_insert - select = None include_insert_from_select_defaults = False @@ -964,8 +961,6 @@ class Update(DMLWhereBase, ValuesBase): __visit_name__ = "update" - _compile_state_factory = DMLState._create_update - _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), @@ -1210,8 +1205,6 @@ class Delete(DMLWhereBase, UpdateBase): __visit_name__ = "delete" - _compile_state_factory = DMLState._create_delete - _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index c1bc9edbcf..287e537242 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -191,7 +191,12 @@ class ClauseElement( __visit_name__ = "clause" - _annotations = {} + _propagate_attrs = util.immutabledict() + """like annotations, however these propagate outwards liberally + as SQL constructs are built, and are set up at construction time. + + """ + supports_execution = False _from_objects = [] bind = None @@ -215,6 +220,16 @@ class ClauseElement( _cache_key_traversal = None + def _set_propagate_attrs(self, values): + # usually, self._propagate_attrs is empty here. one case where it's + # not is a subquery against ORM select, that is then pulled as a + # property of an aliased class. should all be good + + # assert not self._propagate_attrs + + self._propagate_attrs = util.immutabledict(values) + return self + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -870,6 +885,7 @@ class ColumnElement( type_=getattr(self, "type", None), _selectable=selectable, ) + co._propagate_attrs = selectable._propagate_attrs co._proxies = [self] if selectable._is_clone_of is not None: co._is_clone_of = selectable._is_clone_of.columns.get(key) @@ -1495,6 +1511,8 @@ class TextClause( _render_label_in_columns_clause = False + _hide_froms = () + def __and__(self, other): # support use in select.where(), query.filter() return and_(self, other) @@ -1509,10 +1527,6 @@ class TextClause( _allow_label_resolve = False - @property - def _hide_froms(self): - return [] - def __init__(self, text, bind=None): self._bind = bind self._bindparams = {} @@ -2093,14 +2107,16 @@ class ClauseList( ) if self.group_contents: self.clauses = [ - coercions.expect(text_converter_role, clause).self_group( - against=self.operator - ) + coercions.expect( + text_converter_role, clause, apply_propagate_attrs=self + ).self_group(against=self.operator) for clause in clauses ] else: self.clauses = [ - coercions.expect(text_converter_role, clause) + coercions.expect( + text_converter_role, clause, apply_propagate_attrs=self + ) for clause in clauses ] self._is_implicitly_boolean = operators.is_boolean(self.operator) @@ -2641,7 +2657,9 @@ class Case(ColumnElement): whenlist = [ ( coercions.expect( - roles.ExpressionElementRole, c + roles.ExpressionElementRole, + c, + apply_propagate_attrs=self, ).self_group(), coercions.expect(roles.ExpressionElementRole, r), ) @@ -2650,7 +2668,9 @@ class Case(ColumnElement): else: whenlist = [ ( - coercions.expect(roles.ColumnArgumentRole, c).self_group(), + coercions.expect( + roles.ColumnArgumentRole, c, apply_propagate_attrs=self + ).self_group(), coercions.expect(roles.ExpressionElementRole, r), ) for (c, r) in whens @@ -2805,7 +2825,10 @@ class Cast(WrapsColumnExpression, ColumnElement): """ self.type = type_api.to_instance(type_) self.clause = coercions.expect( - roles.ExpressionElementRole, expression, type_=self.type + roles.ExpressionElementRole, + expression, + type_=self.type, + apply_propagate_attrs=self, ) self.typeclause = TypeClause(self.type) @@ -2906,7 +2929,10 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): """ self.type = type_api.to_instance(type_) self.clause = coercions.expect( - roles.ExpressionElementRole, expression, type_=self.type + roles.ExpressionElementRole, + expression, + type_=self.type, + apply_propagate_attrs=self, ) @property @@ -3031,6 +3057,7 @@ class UnaryExpression(ColumnElement): ): self.operator = operator self.modifier = modifier + self._propagate_attrs = element._propagate_attrs self.element = element.self_group( against=self.operator or self.modifier ) @@ -3474,6 +3501,7 @@ class BinaryExpression(ColumnElement): if isinstance(operator, util.string_types): operator = operators.custom_op(operator) self._orig = (left.__hash__(), right.__hash__()) + self._propagate_attrs = left._propagate_attrs or right._propagate_attrs self.left = left.self_group(against=operator) self.right = right.self_group(against=operator) self.operator = operator @@ -4159,6 +4187,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): name=name if name else self.name, disallow_is_literal=True, ) + e._propagate_attrs = selectable._propagate_attrs e._proxies.append(self) if self._type is not None: e.type = self._type @@ -4340,16 +4369,10 @@ class ColumnClause( return other.proxy_set.intersection(self.proxy_set) def get_children(self, column_tables=False, **kw): - if column_tables and self.table is not None: - # TODO: this is only used by ORM query deep_entity_zero. - # this is being removed in a later release so remove - # column_tables also at that time. - return [self.table] - else: - # override base get_children() to not return the Table - # or selectable that is parent to this column. Traversals - # expect the columns of tables and subqueries to be leaf nodes. - return [] + # override base get_children() to not return the Table + # or selectable that is parent to this column. Traversals + # expect the columns of tables and subqueries to be leaf nodes. + return [] @HasMemoized.memoized_attribute def _from_objects(self): @@ -4474,6 +4497,7 @@ class ColumnClause( _selectable=selectable, is_literal=is_literal, ) + c._propagate_attrs = selectable._propagate_attrs if name is None: c.key = self.key c._proxies = [self] diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index cedb76f559..6b1172eba4 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -107,6 +107,7 @@ class FunctionElement(Executable, ColumnElement, FromClause): roles.ExpressionElementRole, c, name=getattr(self, "name", None), + apply_propagate_attrs=self, ) for c in clauses ] @@ -749,7 +750,10 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): if parsed_args is None: parsed_args = [ coercions.expect( - roles.ExpressionElementRole, c, name=self.name + roles.ExpressionElementRole, + c, + name=self.name, + apply_propagate_attrs=self, ) for c in args ] @@ -813,7 +817,12 @@ class ReturnTypeFromArgs(GenericFunction): def __init__(self, *args, **kwargs): args = [ - coercions.expect(roles.ExpressionElementRole, c, name=self.name) + coercions.expect( + roles.ExpressionElementRole, + c, + name=self.name, + apply_propagate_attrs=self, + ) for c in args ] kwargs.setdefault("type_", _type_from_args(args)) @@ -944,7 +953,12 @@ class array_agg(GenericFunction): type = sqltypes.ARRAY def __init__(self, *args, **kwargs): - args = [coercions.expect(roles.ExpressionElementRole, c) for c in args] + args = [ + coercions.expect( + roles.ExpressionElementRole, c, apply_propagate_attrs=self + ) + for c in args + ] default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) if "type_" not in kwargs: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index b861f721b0..d0f4fef601 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -142,12 +142,20 @@ class AnonymizedFromClauseRole(StrictFromClauseRole): class CoerceTextStatementRole(SQLRole): - _role_name = "Executable SQL, text() construct, or string statement" + _role_name = "Executable SQL or text() construct" + + +# _executable_statement = None class StatementRole(CoerceTextStatementRole): _role_name = "Executable SQL or text() construct" + _is_future = False + + def _get_plugin_compile_state_cls(self, name): + return None + class ReturnsRowsRole(StatementRole): _role_name = ( diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 65f8bd81c0..263f579def 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1632,6 +1632,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) c.table = selectable + c._propagate_attrs = selectable._propagate_attrs if selectable._is_clone_of is not None: c._is_clone_of = selectable._is_clone_of.columns.get(c.key) if self.primary_key: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6a552c18c7..008959aec4 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1342,7 +1342,9 @@ class AliasedReturnsRows(NoInit, FromClause): raise NotImplementedError() def _init(self, selectable, name=None): - self.element = selectable + self.element = coercions.expect( + roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self + ) self.supports_execution = selectable.supports_execution if self.supports_execution: self._execution_options = selectable._execution_options @@ -3026,6 +3028,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): ) +@CompileState.plugin_for("default", "compound_select") class CompoundSelectState(CompileState): @util.memoized_property def _label_resolve_dict(self): @@ -3058,7 +3061,6 @@ class CompoundSelect(HasCompileState, GenerativeSelect): """ __visit_name__ = "compound_select" - _compile_state_factory = CompoundSelectState._create _traverse_internals = [ ("selects", InternalTraversal.dp_clauseelement_list), @@ -3425,6 +3427,7 @@ class DeprecatedSelectGenerations(object): self.select_from.non_generative(self, fromclause) +@CompileState.plugin_for("default", "select") class SelectState(CompileState): class default_select_compile_options(CacheableOptions): _cache_key_traversal = [] @@ -3462,7 +3465,7 @@ class SelectState(CompileState): ) if not seen.intersection(item._cloned_set): froms.append(item) - seen.update(item._cloned_set) + seen.update(item._cloned_set) return froms @@ -3714,12 +3717,29 @@ class SelectState(CompileState): return replace_from_obj_index +class _SelectFromElements(object): + def _iterate_from_elements(self): + # note this does not include elements + # in _setup_joins or _legacy_setup_joins + + return itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in self._raw_columns] + ), + itertools.chain.from_iterable( + [element._from_objects for element in self._where_criteria] + ), + self._from_obj, + ) + + class Select( HasPrefixes, HasSuffixes, HasHints, HasCompileState, DeprecatedSelectGenerations, + _SelectFromElements, GenerativeSelect, ): """Represents a ``SELECT`` statement. @@ -3728,7 +3748,6 @@ class Select( __visit_name__ = "select" - _compile_state_factory = SelectState._create _is_future = False _setup_joins = () _legacy_setup_joins = () @@ -4047,7 +4066,7 @@ class Select( if cols_present: self._raw_columns = [ coercions.expect( - roles.ColumnsClauseRole, c, apply_plugins=self + roles.ColumnsClauseRole, c, apply_propagate_attrs=self ) for c in columns ] @@ -4073,17 +4092,6 @@ class Select( cols = list(elem._select_iterable) return cols[0].type - def _iterate_from_elements(self): - return itertools.chain( - itertools.chain.from_iterable( - [element._from_objects for element in self._raw_columns] - ), - itertools.chain.from_iterable( - [element._from_objects for element in self._where_criteria] - ), - self._from_obj, - ) - @property def froms(self): """Return the displayed list of FromClause elements.""" @@ -4192,14 +4200,16 @@ class Select( self._raw_columns = self._raw_columns + [ coercions.expect( - roles.ColumnsClauseRole, column, apply_plugins=self + roles.ColumnsClauseRole, column, apply_propagate_attrs=self ) for column in columns ] def _set_entities(self, entities): self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self) + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) for ent in util.to_list(entities) ] @@ -4342,14 +4352,24 @@ class Select( self._raw_columns = rc @property - def _whereclause(self): - """Legacy, return the WHERE clause as a """ - """:class:`_expression.BooleanClauseList`""" + def whereclause(self): + """Return the completed WHERE clause for this :class:`.Select` + statement. + + This assembles the current collection of WHERE criteria + into a single :class:`_expression.BooleanClauseList` construct. + + + .. versionadded:: 1.4 + + """ return BooleanClauseList._construct_for_whereclause( self._where_criteria ) + _whereclause = whereclause + @_generative def where(self, whereclause): """return a new select() construct with the given expression added to @@ -4430,7 +4450,7 @@ class Select( self._from_obj += tuple( coercions.expect( - roles.FromClauseRole, fromclause, apply_plugins=self + roles.FromClauseRole, fromclause, apply_propagate_attrs=self ) for fromclause in froms ) diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index a308feb7ca..482248ada4 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -179,7 +179,7 @@ class HasCacheKey(object): if NO_CACHE in _anon_map: return None else: - return CacheKey(key, bindparams) + return CacheKey(key, bindparams, self) @classmethod def _generate_cache_key_for_object(cls, obj): @@ -190,7 +190,7 @@ class HasCacheKey(object): if NO_CACHE in _anon_map: return None else: - return CacheKey(key, bindparams) + return CacheKey(key, bindparams, obj) class MemoizedHasCacheKey(HasCacheKey, HasMemoized): @@ -199,9 +199,42 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized): return HasCacheKey._generate_cache_key(self) -class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): +class CacheKey(namedtuple("CacheKey", ["key", "bindparams", "statement"])): def __hash__(self): - return hash(self.key) + """CacheKey itself is not hashable - hash the .key portion""" + + return None + + def to_offline_string(self, statement_cache, parameters): + """generate an "offline string" form of this :class:`.CacheKey` + + The "offline string" is basically the string SQL for the + statement plus a repr of the bound parameter values in series. + Whereas the :class:`.CacheKey` object is dependent on in-memory + identities in order to work as a cache key, the "offline" version + is suitable for a cache that will work for other processes as well. + + The given "statement_cache" is a dictionary-like object where the + string form of the statement itself will be cached. this dictionary + should be in a longer lived scope in order to reduce the time spent + stringifying statements. + + + """ + if self.key not in statement_cache: + statement_cache[self.key] = sql_str = str(self.statement) + else: + sql_str = statement_cache[self.key] + + return repr( + ( + sql_str, + tuple( + parameters.get(bindparam.key, bindparam.value) + for bindparam in self.bindparams + ), + ) + ) def __eq__(self, other): return self.key == other.key @@ -411,7 +444,6 @@ class _CacheKey(ExtendedInternalTraversal): def visit_setup_join_tuple( self, attrname, obj, parent, anon_map, bindparams ): - # TODO: look at attrname for "legacy_join" and use different structure return tuple( ( target._gen_cache_key(anon_map, bindparams), @@ -596,7 +628,6 @@ class _CopyInternals(InternalTraversal): def visit_setup_join_tuple( self, attrname, parent, element, clone=_clone, **kw ): - # TODO: look at attrname for "legacy_join" and use different structure return tuple( ( clone(target, **kw) if target is not None else None, @@ -668,6 +699,15 @@ class _CopyInternals(InternalTraversal): _copy_internals = _CopyInternals() +def _flatten_clauseelement(element): + while hasattr(element, "__clause_element__") and not getattr( + element, "is_clause_element", False + ): + element = element.__clause_element__() + + return element + + class _GetChildren(InternalTraversal): """Generate a _children_traversal internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -696,6 +736,17 @@ class _GetChildren(InternalTraversal): def visit_clauseelement_unordered_set(self, element, **kw): return element + def visit_setup_join_tuple(self, element, **kw): + for (target, onclause, from_, flags) in element: + if from_ is not None: + yield from_ + + if not isinstance(target, str): + yield _flatten_clauseelement(target) + + # if onclause is not None and not isinstance(onclause, str): + # yield _flatten_clauseelement(onclause) + def visit_dml_ordered_values(self, element, **kw): for k, v in element: if hasattr(k, "__clause_element__"): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 030fd2fdeb..683f545dd0 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -591,6 +591,7 @@ def iterate(obj, opts=util.immutabledict()): """ yield obj children = obj.get_children(**opts) + if not children: return diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 24e96dfab4..92bd452a52 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -294,6 +294,7 @@ def count_functions(variance=0.05): print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) stats.sort_stats(_profile_stats.sort) stats.print_stats() + # stats.print_callers() if _profile_stats.force_write: _profile_stats.replace(callcount) elif expected_count: diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 9a832ba1b0..6056864947 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -97,17 +97,13 @@ class FacadeDict(ImmutableContainer, dict): def __new__(cls, *args): new = dict.__new__(cls) - dict.__init__(new, *args) return new - def __init__(self, *args): - pass - - # note that currently, "copy()" is used as a way to get a plain dict - # from an immutabledict, while also allowing the method to work if the - # dictionary is already a plain dict. - # def copy(self): - # return immutabledict.__new__(immutabledict, self) + def copy(self): + raise NotImplementedError( + "an immutabledict shouldn't need to be copied. use dict(d) " + "if you need a mutable dictionary." + ) def __reduce__(self): return FacadeDict, (dict(self),) diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 494e078ab1..2d5acca3bb 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -129,7 +129,9 @@ def profile_memory( ) raise + gc_collect() + samples.append( get_num_objects() if get_num_objects is not None @@ -1299,7 +1301,7 @@ class CycleTest(_fixtures.FixtureTest): # others as of cache key. The options themselves are now part of # QueryCompileState which is not eagerly disposed yet, so this # adds some more. - @assert_cycles(36) + @assert_cycles(37) def go(): generate() @@ -1370,7 +1372,7 @@ class CycleTest(_fixtures.FixtureTest): @assert_cycles(4) def go(): - result = s.execute(stmt) + result = s.connection(mapper=User).execute(stmt) while True: row = result.fetchone() if row is None: diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index ac6e6b55e4..7456d5f5b7 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -834,7 +834,7 @@ class JoinedEagerLoadTest(fixtures.MappedTest): def test_fetch_results(self): A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") - sess = Session() + sess = Session(testing.db) q = sess.query(A).options( joinedload(A.bs).joinedload(B.cs).joinedload(C.ds), @@ -842,16 +842,26 @@ class JoinedEagerLoadTest(fixtures.MappedTest): defaultload(A.es).joinedload(E.gs), ) - context = q._compile_context() - compile_state = context.compile_state - orig_attributes = dict(compile_state.attributes) + compile_state = q._compile_state() + + from sqlalchemy.orm.context import ORMCompileState @profiling.function_call_count() def go(): for i in range(100): - # make sure these get reset each time - context.attributes = orig_attributes.copy() - obj = q._execute_and_instances(context) + exec_opts = {} + bind_arguments = {} + ORMCompileState.orm_pre_session_exec( + sess, compile_state.query, exec_opts, bind_arguments + ) + + r = sess.connection().execute( + compile_state.statement, + execution_options=exec_opts, + bind_arguments=bind_arguments, + ) + r.context.compiled.compile_state = compile_state + obj = ORMCompileState.orm_setup_cursor_result(sess, {}, r) list(obj) sess.close() diff --git a/test/base/test_result.py b/test/base/test_result.py index b179c34620..6cffcc3231 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -115,6 +115,19 @@ class ResultTupleTest(fixtures.TestBase): is_true("a" not in keyed_tuple) is_true("z" not in keyed_tuple) + def test_contains_mapping(self): + keyed_tuple = self._fixture(["x", "y"], ["a", "b"])._mapping + + is_false("x" in keyed_tuple) + is_false("z" in keyed_tuple) + + is_true("z" not in keyed_tuple) + is_true("x" not in keyed_tuple) + + # we do keys + is_true("a" in keyed_tuple) + is_true("b" in keyed_tuple) + def test_none_label(self): keyed_tuple = self._fixture([1, 2, 3], ["a", None, "b"]) eq_(str(keyed_tuple), "(1, 2, 3)") @@ -841,15 +854,12 @@ class OnlyScalarsTest(fixtures.TestBase): def no_tuple_fixture(self): data = [(1, 1, 1), (2, 1, 2), (1, 1, 1), (1, 3, 2), (4, 1, 2)] - def chunks(num, as_tuples): + def chunks(num): while data: rows = data[0:num] data[:] = [] - if as_tuples: - assert False - else: - yield [row[0] for row in rows] + yield [row[0] for row in rows] return chunks @@ -857,15 +867,12 @@ class OnlyScalarsTest(fixtures.TestBase): def normal_fixture(self): data = [(1, 1, 1), (2, 1, 2), (1, 1, 1), (1, 3, 2), (4, 1, 2)] - def chunks(num, as_tuples): + def chunks(num): while data: rows = data[0:num] data[:] = [] - if as_tuples: - yield rows - else: - yield [row[0] for row in rows] + yield [row[0] for row in rows] return chunks @@ -891,6 +898,26 @@ class OnlyScalarsTest(fixtures.TestBase): eq_(r.all(), [1, 2, 4]) + def test_scalar_mode_scalars_fetchmany(self, normal_fixture): + metadata = result.SimpleResultMetaData(["a", "b", "c"]) + + r = result.ChunkedIteratorResult( + metadata, normal_fixture, source_supports_scalars=True + ) + + r = r.scalars() + eq_(list(r.partitions(2)), [[1, 2], [1, 1], [4]]) + + def test_scalar_mode_unique_scalars_fetchmany(self, normal_fixture): + metadata = result.SimpleResultMetaData(["a", "b", "c"]) + + r = result.ChunkedIteratorResult( + metadata, normal_fixture, source_supports_scalars=True + ) + + r = r.scalars().unique() + eq_(list(r.partitions(2)), [[1, 2], [4]]) + def test_scalar_mode_unique_tuples_all(self, normal_fixture): metadata = result.SimpleResultMetaData(["a", "b", "c"]) @@ -900,7 +927,7 @@ class OnlyScalarsTest(fixtures.TestBase): r = r.unique() - eq_(r.all(), [(1, 1, 1), (2, 1, 2), (1, 3, 2), (4, 1, 2)]) + eq_(r.all(), [(1,), (2,), (4,)]) def test_scalar_mode_tuples_all(self, normal_fixture): metadata = result.SimpleResultMetaData(["a", "b", "c"]) @@ -909,7 +936,7 @@ class OnlyScalarsTest(fixtures.TestBase): metadata, normal_fixture, source_supports_scalars=True ) - eq_(r.all(), [(1, 1, 1), (2, 1, 2), (1, 1, 1), (1, 3, 2), (4, 1, 2)]) + eq_(r.all(), [(1,), (2,), (1,), (1,), (4,)]) def test_scalar_mode_scalars_iterate(self, no_tuple_fixture): metadata = result.SimpleResultMetaData(["a", "b", "c"]) @@ -929,4 +956,4 @@ class OnlyScalarsTest(fixtures.TestBase): metadata, normal_fixture, source_supports_scalars=True ) - eq_(list(r), [(1, 1, 1), (2, 1, 2), (1, 1, 1), (1, 3, 2), (4, 1, 2)]) + eq_(list(r), [(1,), (2,), (1,), (1,), (4,)]) diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py index 77e57aa367..ecb5e3919b 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -289,6 +289,7 @@ class LikeQueryTest(BakedTest): # with multiple params, the **kwargs will be used bq += lambda q: q.filter(User.id == bindparam("anid")) eq_(bq(sess).params(uname="fred", anid=9).count(), 1) + eq_( # wrong id, so 0 results: bq(sess).params(uname="fred", anid=8).count(), @@ -388,7 +389,12 @@ class ResultPostCriteriaTest(BakedTest): def before_execute( conn, clauseelement, multiparams, params, execution_options ): - assert "yes" in conn._execution_options + # execution options are kind of moving around a bit, + # test both places + assert ( + "yes" in clauseelement._execution_options + or "yes" in execution_options + ) bq = self.bakery(lambda s: s.query(User.id).order_by(User.id)) @@ -804,9 +810,7 @@ class ResultTest(BakedTest): Address = self.classes.Address Order = self.classes.Order - # Override the default bakery for one with a smaller size. This used to - # trigger a bug when unbaking subqueries. - self.bakery = baked.bakery(size=3) + self.bakery = baked.bakery() base_bq = self.bakery(lambda s: s.query(User)) base_bq += lambda q: q.options( @@ -840,6 +844,7 @@ class ResultTest(BakedTest): for cond1, cond2 in itertools.product( *[(False, True) for j in range(2)] ): + print("HI----") bq = base_bq._clone() sess = Session() @@ -903,7 +908,7 @@ class ResultTest(BakedTest): ) ] - self.bakery = baked.bakery(size=3) + self.bakery = baked.bakery() bq = self.bakery(lambda s: s.query(User)) @@ -1288,33 +1293,72 @@ class LazyLoaderTest(testing.AssertsCompiledSQL, BakedTest): def _test_baked_lazy_loading_relationship_flag(self, flag): User, Address = self._o2m_fixture(bake_queries=flag) + from sqlalchemy import inspect - sess = Session() - u1 = sess.query(User).first() - - from sqlalchemy.orm import Query - - canary = mock.Mock() + address_mapper = inspect(Address) + sess = Session(testing.db) + + # there's no event in the compile process either at the ORM + # or core level and it is not easy to patch. the option object + # is the one thing that will get carried into the lazyload from the + # outside and invoked on a per-compile basis + mock_opt = mock.Mock( + _is_compile_state=True, + propagate_to_loaders=True, + _gen_cache_key=lambda *args: ("hi",), + _generate_path_cache_key=lambda path: ("hi",), + ) - # I would think Mock can do this but apparently - # it cannot (wrap / autospec don't work together) - real_compile_state = Query._compile_state + u1 = sess.query(User).options(mock_opt).first() - def _my_compile_state(*arg, **kw): - if arg[0].column_descriptions[0]["entity"] is Address: - canary() - return real_compile_state(*arg, **kw) + @event.listens_for(sess, "do_orm_execute") + def _my_compile_state(context): + if ( + context.statement._raw_columns[0]._annotations["parententity"] + is address_mapper + ): + mock_opt.orm_execute() - with mock.patch.object(Query, "_compile_state", _my_compile_state): - u1.addresses + u1.addresses - sess.expire(u1) - u1.addresses + sess.expire(u1) + u1.addresses if flag: - eq_(canary.call_count, 1) + eq_( + mock_opt.mock_calls, + [ + mock.call.process_query(mock.ANY), + mock.call.process_compile_state(mock.ANY), # query.first() + mock.call.process_query_conditionally(mock.ANY), + mock.call.orm_execute(), # lazyload addresses + mock.call.process_compile_state(mock.ANY), # emit lazyload + mock.call.process_compile_state( + mock.ANY + ), # load scalar attributes for user + # lazyload addresses, no call to process_compile_state + mock.call.orm_execute(), + ], + ) else: - eq_(canary.call_count, 2) + eq_( + mock_opt.mock_calls, + [ + mock.call.process_query(mock.ANY), + mock.call.process_compile_state(mock.ANY), # query.first() + mock.call.process_query_conditionally(mock.ANY), + mock.call.orm_execute(), # lazyload addresses + mock.call.process_compile_state(mock.ANY), # emit_lazyload + mock.call.process_compile_state( + mock.ANY + ), # load_scalar_attributes for user + mock.call.process_query_conditionally(mock.ANY), + mock.call.orm_execute(), # lazyload addresses + mock.call.process_compile_state( + mock.ANY + ), # emit_lazyload, here the query was not cached + ], + ) def test_baked_lazy_loading_option_o2m(self): User, Address = self._o2m_fixture() @@ -1571,58 +1615,57 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): return User, Address def _query_fixture(self): - from sqlalchemy.orm.query import Query, _generative + from sqlalchemy.orm.query import Query class CachingQuery(Query): cache = {} - @_generative def set_cache_key(self, key): - self._cache_key = key - - # in 1.4 / If1a23824ffb77d8d58cf2338cf35dd6b5963b17f , - # we no longer override ``__iter__`` because we need the - # whole result object. The FrozenResult is added for this - # use case. A new session-level event will be added within - # the scope of ORM /execute() integration so that people - # don't have to subclass this anymore. - - def _execute_and_instances(self, context, **kw): - super_ = super(CachingQuery, self) - - if hasattr(self, "_cache_key"): - return self.get_value( - createfunc=lambda: super_._execute_and_instances( - context, **kw - ) - ) - else: - return super_._execute_and_instances(context, **kw) - - def get_value(self, createfunc): - if self._cache_key in self.cache: - return self.cache[self._cache_key]() - else: - self.cache[ - self._cache_key - ] = retval = createfunc().freeze() - return retval() + return self.execution_options(_cache_key=key) + + def set_cache_key_for_path(self, path, key): + return self.execution_options(**{"_cache_key_%s" % path: key}) + + def get_value(cache_key, cache, createfunc): + if cache_key in cache: + return cache[cache_key]() + else: + cache[cache_key] = retval = createfunc().freeze() + return retval() + + s1 = Session(query_cls=CachingQuery) + + @event.listens_for(s1, "do_orm_execute", retval=True) + def do_orm_execute(orm_context): + ckey = None + statement = orm_context.orm_query + for opt in orm_context.user_defined_options: + ckey = opt.get_cache_key(orm_context) + if ckey: + break + else: + if "_cache_key" in statement._execution_options: + ckey = statement._execution_options["_cache_key"] + + if ckey is not None: + return get_value( + ckey, CachingQuery.cache, orm_context.invoke_statement, + ) - return Session(query_cls=CachingQuery) + return s1 def _option_fixture(self): - from sqlalchemy.orm.interfaces import MapperOption + from sqlalchemy.orm.interfaces import UserDefinedOption - class RelationshipCache(MapperOption): + class RelationshipCache(UserDefinedOption): propagate_to_loaders = True - def process_query_conditionally(self, query): - if query._current_path: - query._cache_key = "user7_addresses" - - def _generate_path_cache_key(self, path): - return None + def get_cache_key(self, orm_context): + if orm_context.loader_strategy_path: + return "user7_addresses" + else: + return None return RelationshipCache() @@ -1641,6 +1684,21 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): eq_(q.all(), [User(id=7, addresses=[Address(id=1)])]) + def test_non_baked_tuples(self): + User, Address = self._o2m_fixture() + + sess = self._query_fixture() + q = sess._query_cls + eq_(q.cache, {}) + + q = sess.query(User).filter(User.id == 7).set_cache_key("user7") + + eq_(sess.execute(q).all(), [(User(id=7, addresses=[Address(id=1)]),)]) + + eq_(list(q.cache), ["user7"]) + + eq_(sess.execute(q).all(), [(User(id=7, addresses=[Address(id=1)]),)]) + def test_use_w_baked(self): User, Address = self._o2m_fixture() diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 77b716b0a7..eb9c5147a8 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -15,6 +15,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import util from sqlalchemy.ext.horizontal_shard import ShardedSession +from sqlalchemy.future import select as future_select from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import create_session from sqlalchemy.orm import deferred @@ -27,11 +28,11 @@ from sqlalchemy.pool import SingletonThreadPool from sqlalchemy.sql import operators from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing import provision from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.engines import testing_reaper - # TODO: ShardTest can be turned into a base for further subclasses @@ -190,11 +191,45 @@ class ShardTest(object): sess.close() return sess - def test_roundtrip(self): + def test_get(self): sess = self._fixture_data() - tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one() - tokyo.city # reload 'city' attribute on tokyo - sess.expire_all() + tokyo = sess.query(WeatherLocation).get(1) + eq_(tokyo.city, "Tokyo") + + newyork = sess.query(WeatherLocation).get(2) + eq_(newyork.city, "New York") + + t2 = sess.query(WeatherLocation).get(1) + is_(t2, tokyo) + + def test_get_explicit_shard(self): + sess = self._fixture_data() + tokyo = sess.query(WeatherLocation).set_shard("europe").get(1) + is_(tokyo, None) + + newyork = sess.query(WeatherLocation).set_shard("north_america").get(2) + eq_(newyork.city, "New York") + + # now it found it + t2 = sess.query(WeatherLocation).get(1) + eq_(t2.city, "Tokyo") + + def test_query_explicit_shard_via_bind_opts(self): + sess = self._fixture_data() + + stmt = future_select(WeatherLocation).filter(WeatherLocation.id == 1) + + tokyo = ( + sess.execute(stmt, bind_arguments={"shard_id": "asia"}) + .scalars() + .first() + ) + + eq_(tokyo.city, "Tokyo") + + def test_plain_db_lookup(self): + self._fixture_data() + # not sure what this is testing except the fixture data itself eq_( db2.execute(weather_locations.select()).fetchall(), [(1, "Asia", "Tokyo")], @@ -206,12 +241,45 @@ class ShardTest(object): (3, "North America", "Toronto"), ], ) + + def test_plain_core_lookup_w_shard(self): + sess = self._fixture_data() eq_( sess.execute( weather_locations.select(), shard_id="asia" ).fetchall(), [(1, "Asia", "Tokyo")], ) + + def test_roundtrip_future(self): + sess = self._fixture_data() + + tokyo = ( + sess.execute( + future_select(WeatherLocation).filter_by(city="Tokyo") + ) + .scalars() + .one() + ) + eq_(tokyo.city, "Tokyo") + + asia_and_europe = sess.execute( + future_select(WeatherLocation).filter( + WeatherLocation.continent.in_(["Europe", "Asia"]) + ) + ).scalars() + eq_( + {c.city for c in asia_and_europe}, {"Tokyo", "London", "Dublin"}, + ) + + def test_roundtrip(self): + sess = self._fixture_data() + tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one() + + eq_(tokyo.city, "Tokyo") + tokyo.city # reload 'city' attribute on tokyo + sess.expire_all() + t = sess.query(WeatherLocation).get(tokyo.id) eq_(t.city, tokyo.city) eq_(t.reports[0].temperature, 80.0) @@ -219,26 +287,23 @@ class ShardTest(object): WeatherLocation.continent == "North America" ) eq_( - set([c.city for c in north_american_cities]), - set(["New York", "Toronto"]), + {c.city for c in north_american_cities}, {"New York", "Toronto"}, ) asia_and_europe = sess.query(WeatherLocation).filter( WeatherLocation.continent.in_(["Europe", "Asia"]) ) eq_( - set([c.city for c in asia_and_europe]), - set(["Tokyo", "London", "Dublin"]), + {c.city for c in asia_and_europe}, {"Tokyo", "London", "Dublin"}, ) # inspect the shard token stored with each instance eq_( - set(inspect(c).key[2] for c in asia_and_europe), - set(["europe", "asia"]), + {inspect(c).key[2] for c in asia_and_europe}, {"europe", "asia"}, ) eq_( - set(inspect(c).identity_token for c in asia_and_europe), - set(["europe", "asia"]), + {inspect(c).identity_token for c in asia_and_europe}, + {"europe", "asia"}, ) newyork = sess.query(WeatherLocation).filter_by(city="New York").one() @@ -324,7 +389,7 @@ class ShardTest(object): canary = [] def load(instance, ctx): - canary.append(ctx.attributes["shard_id"]) + canary.append(ctx.bind_arguments["shard_id"]) event.listen(WeatherLocation, "load", load) sess = self._fixture_data() @@ -571,6 +636,9 @@ class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest): s.commit() def _session_fixture(self, **kw): + # the "fake" key here is to ensure that neither id_chooser + # nor query_chooser are actually used, only shard_chooser + # should be used. return ShardedSession( shards={"main": testing.db}, diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index 9ee5ce2ab6..5494145078 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -4,6 +4,7 @@ from sqlalchemy import func from sqlalchemy import select from sqlalchemy import testing from sqlalchemy import true +from sqlalchemy.future import select as future_select from sqlalchemy.orm import aliased from sqlalchemy.orm import create_session from sqlalchemy.orm import defaultload @@ -209,10 +210,66 @@ class _PolymorphicTestBase(object): ], ) + def test_multi_join_future(self): + sess = create_session() + e = aliased(Person) + c = aliased(Company) + + q = ( + future_select(Company, Person, c, e) + .join(Person, Company.employees) + .join(e, c.employees) + .filter(Person.person_id != e.person_id) + .filter(Person.name == "dilbert") + .filter(e.name == "wally") + ) + + eq_( + sess.execute( + future_select(func.count()).select_from(q.subquery()) + ).scalar(), + 1, + ) + + eq_( + sess.execute(q).all(), + [ + ( + Company(company_id=1, name="MegaCorp, Inc."), + Engineer( + status="regular engineer", + engineer_name="dilbert", + name="dilbert", + company_id=1, + primary_language="java", + person_id=1, + type="engineer", + ), + Company(company_id=1, name="MegaCorp, Inc."), + Engineer( + status="regular engineer", + engineer_name="wally", + name="wally", + company_id=1, + primary_language="c++", + person_id=2, + type="engineer", + ), + ) + ], + ) + def test_filter_on_subclass_one(self): sess = create_session() eq_(sess.query(Engineer).all()[0], Engineer(name="dilbert")) + def test_filter_on_subclass_one_future(self): + sess = create_session() + eq_( + sess.execute(future_select(Engineer)).scalar(), + Engineer(name="dilbert"), + ) + def test_filter_on_subclass_two(self): sess = create_session() eq_(sess.query(Engineer).first(), Engineer(name="dilbert")) @@ -261,6 +318,20 @@ class _PolymorphicTestBase(object): [b1, m1], ) + def test_join_from_polymorphic_nonaliased_one_future(self): + sess = create_session() + eq_( + sess.execute( + future_select(Person) + .join(Person.paperwork) + .filter(Paperwork.description.like("%review%")) + ) + .unique() + .scalars() + .all(), + [b1, m1], + ) + def test_join_from_polymorphic_nonaliased_two(self): sess = create_session() eq_( @@ -306,6 +377,23 @@ class _PolymorphicTestBase(object): [b1, m1], ) + def test_join_from_polymorphic_flag_aliased_one_future(self): + sess = create_session() + + pa = aliased(Paperwork) + eq_( + sess.execute( + future_select(Person) + .order_by(Person.person_id) + .join(Person.paperwork.of_type(pa)) + .filter(pa.description.like("%review%")) + ) + .unique() + .scalars() + .all(), + [b1, m1], + ) + def test_join_from_polymorphic_explicit_aliased_one(self): sess = create_session() pa = aliased(Paperwork) @@ -389,6 +477,23 @@ class _PolymorphicTestBase(object): [b1, m1], ) + def test_join_from_with_polymorphic_nonaliased_one_future(self): + sess = create_session() + + pm = with_polymorphic(Person, [Manager]) + eq_( + sess.execute( + future_select(pm) + .order_by(pm.person_id) + .join(pm.paperwork) + .filter(Paperwork.description.like("%review%")) + ) + .unique() + .scalars() + .all(), + [b1, m1], + ) + def test_join_from_with_polymorphic_nonaliased_two(self): sess = create_session() eq_( @@ -1429,6 +1534,7 @@ class _PolymorphicTestBase(object): .filter(Engineer.primary_language == "java") .statement.scalar_subquery() ) + eq_(sess.query(Person).filter(Person.person_id.in_(subq)).one(), e1) def test_mixed_entities_one(self): diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index 63588d73ee..8e6d73ca44 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -1,19 +1,26 @@ import sqlalchemy as sa from sqlalchemy import ForeignKey +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import select +from sqlalchemy import table from sqlalchemy import testing +from sqlalchemy import true +from sqlalchemy.future import select as future_select from sqlalchemy.orm import backref from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm.query import Query from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import mock from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -155,7 +162,92 @@ class BindIntegrationTest(_fixtures.FixtureTest): assert len(session.query(User).filter_by(name="Johnny").all()) == 0 session.close() - def test_bind_arguments(self): + @testing.combinations( + (lambda: {}, "e3"), + (lambda e1: {"bind": e1}, "e1"), + (lambda e1, Address: {"bind": e1, "mapper": Address}, "e1"), + ( + lambda e1, Address: { + "bind": e1, + "clause": Query([Address])._statement_20(), + }, + "e1", + ), + (lambda Address: {"mapper": Address}, "e2"), + (lambda Address: {"clause": Query([Address])._statement_20()}, "e2"), + (lambda addresses: {"clause": select([addresses])}, "e2"), + ( + lambda User, addresses: { + "mapper": User, + "clause": select([addresses]), + }, + "e1", + ), + ( + lambda e2, User, addresses: { + "mapper": User, + "clause": select([addresses]), + "bind": e2, + }, + "e2", + ), + ( + lambda User, Address: { + "clause": future_select(1).join_from(User, Address) + }, + "e1", + ), + ( + lambda User, Address: { + "clause": future_select(1).join_from(Address, User) + }, + "e2", + ), + ( + lambda User: { + "clause": future_select(1).where(User.name == "ed"), + }, + "e1", + ), + (lambda: {"clause": future_select(1)}, "e3"), + (lambda User: {"clause": Query([User])._statement_20()}, "e1"), + (lambda: {"clause": Query([1])._statement_20()}, "e3"), + ( + lambda User: { + "clause": Query([1]).select_from(User)._statement_20() + }, + "e1", + ), + ( + lambda User: { + "clause": Query([1]) + .select_from(User) + .join(User.addresses) + ._statement_20() + }, + "e1", + ), + ( + # forcing the "onclause" argument to be considered + # in visitors.iterate() + lambda User: { + "clause": Query([1]) + .select_from(User) + .join(table("foo"), User.addresses) + ._statement_20() + }, + "e1", + ), + ( + lambda User: { + "clause": future_select(1) + .select_from(User) + .join(User.addresses) + }, + "e1", + ), + ) + def test_get_bind(self, testcase, expected): users, Address, addresses, User = ( self.tables.users, self.classes.Address, @@ -163,33 +255,130 @@ class BindIntegrationTest(_fixtures.FixtureTest): self.classes.User, ) - mapper(User, users) + mapper(User, users, properties={"addresses": relationship(Address)}) mapper(Address, addresses) e1 = engines.testing_engine() e2 = engines.testing_engine() e3 = engines.testing_engine() + testcase = testing.resolve_lambda( + testcase, + User=User, + Address=Address, + e1=e1, + e2=e2, + e3=e3, + addresses=addresses, + ) + sess = Session(e3) sess.bind_mapper(User, e1) sess.bind_mapper(Address, e2) - assert sess.connection().engine is e3 - assert sess.connection(bind=e1).engine is e1 - assert sess.connection(mapper=Address, bind=e1).engine is e1 - assert sess.connection(mapper=Address).engine is e2 - assert sess.connection(clause=addresses.select()).engine is e2 - assert ( - sess.connection(mapper=User, clause=addresses.select()).engine - is e1 + engine = {"e1": e1, "e2": e2, "e3": e3}[expected] + conn = sess.connection(**testcase) + is_(conn.engine, engine) + + sess.close() + + @testing.combinations( + ( + lambda session, Address: session.query(Address), + lambda Address: {"mapper": inspect(Address), "clause": mock.ANY}, + "e2", + ), + (lambda: future_select(1), lambda: {"clause": mock.ANY}, "e3"), + ( + lambda User, Address: future_select(1).join_from(User, Address), + lambda User: {"clause": mock.ANY, "mapper": inspect(User)}, + "e1", + ), + ( + lambda User, Address: future_select(1).join_from(Address, User), + lambda Address: {"clause": mock.ANY, "mapper": inspect(Address)}, + "e2", + ), + ( + lambda User: future_select(1).where(User.name == "ed"), + # no mapper for this one becuase the plugin is not "orm" + lambda User: {"clause": mock.ANY}, + "e1", + ), + ( + lambda User: future_select(1) + .select_from(User) + .where(User.name == "ed"), + lambda User: {"clause": mock.ANY, "mapper": inspect(User)}, + "e1", + ), + ( + lambda User: future_select(User.id), + lambda User: {"clause": mock.ANY, "mapper": inspect(User)}, + "e1", + ), + ) + def test_bind_through_execute( + self, statement, expected_get_bind_args, expected_engine_name + ): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, ) - assert ( - sess.connection( - mapper=User, clause=addresses.select(), bind=e2 - ).engine - is e2 + + mapper(User, users, properties={"addresses": relationship(Address)}) + mapper(Address, addresses) + + e1 = engines.testing_engine() + e2 = engines.testing_engine() + e3 = engines.testing_engine() + + canary = mock.Mock() + + class GetBindSession(Session): + def _connection_for_bind(self, bind, **kw): + canary._connection_for_bind(bind, **kw) + return mock.Mock() + + def get_bind(self, **kw): + canary.get_bind(**kw) + return Session.get_bind(self, **kw) + + sess = GetBindSession(e3) + sess.bind_mapper(User, e1) + sess.bind_mapper(Address, e2) + + lambda_args = dict( + session=sess, + User=User, + Address=Address, + e1=e1, + e2=e2, + e3=e3, + addresses=addresses, + ) + statement = testing.resolve_lambda(statement, **lambda_args) + + expected_get_bind_args = testing.resolve_lambda( + expected_get_bind_args, **lambda_args ) + engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name] + + with mock.patch( + "sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result" + ): + sess.execute(statement) + + eq_( + canary.mock_calls, + [ + mock.call.get_bind(**expected_get_bind_args), + mock.call._connection_for_bind(engine, close_with_result=True), + ], + ) sess.close() def test_bind_arg(self): @@ -495,3 +684,41 @@ class GetBindTest(fixtures.MappedTest): is_(session.get_bind(self.classes.BaseClass), base_class_bind) is_(session.get_bind(self.classes.ConcreteSubClass), concrete_sub_bind) + + @testing.fixture + def two_table_fixture(self): + base_class_bind = Mock(name="base") + concrete_sub_bind = Mock(name="concrete") + + session = self._fixture( + { + self.tables.base_table: base_class_bind, + self.tables.concrete_sub_table: concrete_sub_bind, + } + ) + return session, base_class_bind, concrete_sub_bind + + def test_bind_selectable_table(self, two_table_fixture): + session, base_class_bind, concrete_sub_bind = two_table_fixture + + is_(session.get_bind(clause=self.tables.base_table), base_class_bind) + is_( + session.get_bind(clause=self.tables.concrete_sub_table), + concrete_sub_bind, + ) + + def test_bind_selectable_join(self, two_table_fixture): + session, base_class_bind, concrete_sub_bind = two_table_fixture + + stmt = self.tables.base_table.join( + self.tables.concrete_sub_table, true() + ) + is_(session.get_bind(clause=stmt), base_class_bind) + + def test_bind_selectable_union(self, two_table_fixture): + session, base_class_bind, concrete_sub_bind = two_table_fixture + + stmt = select([self.tables.base_table]).union( + select([self.tables.concrete_sub_table]) + ) + is_(session.get_bind(clause=stmt), base_class_bind) diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 53cb451716..b431ea6b25 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -69,6 +69,7 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): joinedload(User.orders).selectinload(Order.items), defer(User.id), defer("id"), + defer("*"), defer(Address.id), joinedload(User.addresses).defer(Address.id), joinedload(aliased(User).addresses).defer(Address.id), @@ -100,6 +101,7 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): Load(User).defer(User.id), Load(User).subqueryload("addresses"), Load(Address).defer("id"), + Load(Address).defer("*"), Load(aliased(Address)).defer("id"), Load(User).joinedload(User.addresses).defer(Address.id), Load(User).joinedload(User.orders).joinedload(Order.items), @@ -111,6 +113,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): .defer(Item.description), Load(User).defaultload(User.orders).defaultload(Order.items), Load(User).defaultload(User.orders), + Load(Address).raiseload("*"), + Load(Address).raiseload("user"), ), compare_values=True, ) diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 86edf53afc..61df1d277e 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -79,7 +79,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): stmt = select(User).join(Address, User.orders) assert_raises_message( exc.InvalidRequestError, - "Selectable 'addresses' is not derived from 'orders'", + "Join target .*Address.* does not correspond to the right side " + "of join condition User.orders", stmt.compile, ) @@ -371,10 +372,7 @@ class ImplicitWithPolymorphicTest( .order_by(Engineer.person_id) ) - # the ORM has a different column selection than what a purely core - # select does, in terms of engineers.person_id vs. people.person_id - - expected = ( + plain_expected = ( # noqa "SELECT engineers.person_id, people.person_id, people.company_id, " "people.name, " "people.type, engineers.status, " @@ -383,9 +381,23 @@ class ImplicitWithPolymorphicTest( "ON people.person_id = engineers.person_id " "WHERE people.name = :name_1 ORDER BY engineers.person_id" ) + # when we have disambiguating labels turned on + disambiguate_expected = ( # noqa + "SELECT engineers.person_id, people.person_id AS person_id_1, " + "people.company_id, " + "people.name, " + "people.type, engineers.status, " + "engineers.engineer_name, engineers.primary_language " + "FROM people JOIN engineers " + "ON people.person_id = engineers.person_id " + "WHERE people.name = :name_1 ORDER BY engineers.person_id" + ) - self.assert_compile(stmt, expected) - self.assert_compile(q.statement, expected) + # these change based on how we decide to apply labels + # in context.py + self.assert_compile(stmt, disambiguate_expected) + + self.assert_compile(q.statement, disambiguate_expected) def test_select_where_columns_subclass(self): diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index f38f917da0..ae3e18b09a 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -319,13 +319,6 @@ class ExpireTest(_fixtures.FixtureTest): ["addresses"], ) - # in contrast to a regular query with no columns - assert_raises_message( - sa_exc.InvalidRequestError, - "no columns with which to SELECT", - s.query().all, - ) - def test_refresh_cancels_expire(self): users, User = self.tables.users, self.classes.User diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 4b20dfca6e..ce687fdeec 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -3389,7 +3389,7 @@ class ExternalColumnsTest(QueryTest): }, ) - mapper(Address, addresses, properties={"user": relationship(User)}) + mapper(Address, addresses, properties={"user": relationship(User,)}) sess = create_session() @@ -3412,6 +3412,11 @@ class ExternalColumnsTest(QueryTest): Address(id=4, user=User(id=8, concat=16, count=3)), Address(id=5, user=User(id=9, concat=18, count=1)), ] + # TODO: ISSUE: BUG: cached metadata is confusing the user.id + # column here with the anon_1 for some reason, when we + # use compiled cache. this bug may even be present in + # regular master / 1.3. right now the caching of result + # metadata is disabled. eq_(sess.query(Address).all(), address_result) # run the eager version twice to test caching of aliased clauses diff --git a/test/orm/test_loading.py b/test/orm/test_loading.py index 968c8229b7..0967171657 100644 --- a/test/orm/test_loading.py +++ b/test/orm/test_loading.py @@ -78,7 +78,7 @@ class InstancesTest(_fixtures.FixtureTest): ctx.compile_state._entities = [ mock.Mock(row_processor=mock.Mock(side_effect=Exception("boom"))) ] - assert_raises(Exception, loading.instances, q, cursor, ctx) + assert_raises(Exception, loading.instances, cursor, ctx) assert cursor.close.called, "Cursor wasn't closed" def test_row_proc_not_created(self): diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index a090d00442..ab52e5aac6 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -256,7 +256,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): m = mapper(User, users) session = create_session() - session.connection(m) + session.connection(mapper=m) def test_incomplete_columns(self): """Loading from a select which does not contain all columns""" diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 030e6c8704..76706b37b8 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -6,6 +6,7 @@ from sqlalchemy import and_ from sqlalchemy import between from sqlalchemy import bindparam from sqlalchemy import Boolean +from sqlalchemy import case from sqlalchemy import cast from sqlalchemy import collate from sqlalchemy import column @@ -29,11 +30,13 @@ from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import true +from sqlalchemy import type_coerce from sqlalchemy import Unicode from sqlalchemy import union from sqlalchemy import util from sqlalchemy.engine import default from sqlalchemy.ext.compiler import compiles +from sqlalchemy.future import select as future_select from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import backref @@ -47,6 +50,7 @@ from sqlalchemy.orm import lazyload from sqlalchemy.orm import mapper from sqlalchemy.orm import Query from sqlalchemy.orm import relationship +from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import subqueryload from sqlalchemy.orm import synonym @@ -59,7 +63,6 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false -from sqlalchemy.testing import is_not_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import assert_raises @@ -149,6 +152,67 @@ class RowTupleTest(QueryTest): eq_(row.id, 7) eq_(row.uname, "jack") + @testing.combinations( + (lambda s, users: s.query(users),), + (lambda s, User: s.query(User.id, User.name),), + (lambda s, users: s.query(users.c.id, users.c.name),), + (lambda s, users: s.query(users.c.id, users.c.name),), + ) + def test_modern_tuple(self, test_case): + # check we are not getting a LegacyRow back + + User, users = self.classes.User, self.tables.users + + mapper(User, users) + + s = Session() + + q = testing.resolve_lambda(test_case, **locals()) + + row = q.order_by(User.id).first() + assert "jack" in row + + @testing.combinations( + (lambda s, users: s.query(users),), + (lambda s, User: s.query(User.id, User.name),), + (lambda s, users: s.query(users.c.id, users.c.name),), + (lambda s, users: future_select(users),), + (lambda s, User: future_select(User.id, User.name),), + (lambda s, users: future_select(users.c.id, users.c.name),), + ) + def test_modern_tuple_future(self, test_case): + # check we are not getting a LegacyRow back + + User, users = self.classes.User, self.tables.users + + mapper(User, users) + + s = Session() + + q = testing.resolve_lambda(test_case, **locals()) + + row = s.execute(q.order_by(User.id)).first() + assert "jack" in row + + @testing.combinations( + (lambda s, users: select([users]),), + (lambda s, User: select([User.id, User.name]),), + (lambda s, users: select([users.c.id, users.c.name]),), + ) + def test_legacy_tuple_old_select(self, test_case): + + User, users = self.classes.User, self.tables.users + + mapper(User, users) + + s = Session() + + q = testing.resolve_lambda(test_case, **locals()) + + row = s.execute(q.order_by(User.id)).first() + assert "jack" not in row + assert "jack" in tuple(row) + def test_entity_mapping_access(self): User, users = self.classes.User, self.tables.users Address, addresses = self.classes.Address, self.tables.addresses @@ -188,34 +252,6 @@ class RowTupleTest(QueryTest): assert_raises(KeyError, lambda: row._mapping[User.name]) assert_raises(KeyError, lambda: row._mapping[users.c.name]) - def test_deep_entity(self): - users, User = (self.tables.users, self.classes.User) - - mapper(User, users) - - sess = create_session() - bundle = Bundle("b1", User.id, User.name) - subq1 = sess.query(User.id).subquery() - subq2 = sess.query(bundle).subquery() - cte = sess.query(User.id).cte() - ex = sess.query(User).exists() - - is_( - sess.query(subq1)._compile_state()._deep_entity_zero(), - inspect(User), - ) - is_( - sess.query(subq2)._compile_state()._deep_entity_zero(), - inspect(User), - ) - is_( - sess.query(cte)._compile_state()._deep_entity_zero(), - inspect(User), - ) - is_( - sess.query(ex)._compile_state()._deep_entity_zero(), inspect(User), - ) - @testing.combinations( lambda sess, User: ( sess.query(User), @@ -4502,22 +4538,65 @@ class TextTest(QueryTest, AssertsCompiledSQL): self.assert_sql_count(testing.db, go, 1) - def test_other_eager_loads(self): - # this is new in 1.4. with textclause, we build up column loaders - # normally, so that eager loaders also get installed. previously, - # _compile_context() didn't build up column loaders and attempted - # to get them after the fact. + def test_columns_multi_table_uselabels_cols_contains_eager(self): + # test that columns using column._label match, as well as that + # ordering doesn't matter. User = self.classes.User + Address = self.classes.Address s = create_session() q = ( s.query(User) - .from_statement(text("select * from users")) - .options(subqueryload(User.addresses)) + .from_statement( + text( + "select users.name AS users_name, users.id AS users_id, " + "addresses.id AS addresses_id FROM users JOIN addresses " + "ON users.id = addresses.user_id WHERE users.id=8 " + "ORDER BY addresses.id" + ).columns(User.name, User.id, Address.id) + ) + .options(contains_eager(User.addresses)) + ) + + def go(): + r = q.all() + eq_(r[0].addresses, [Address(id=2), Address(id=3), Address(id=4)]) + + self.assert_sql_count(testing.db, go, 1) + + @testing.combinations( + ( + False, + subqueryload, + # sqlite seems happy to interpret the broken SQL and give you the + # correct result somehow, this is a bug in SQLite so don't rely + # upon it doing that + testing.fails("not working yet") + testing.skip_if("sqlite"), + ), + (True, subqueryload, testing.fails("not sure about implementation")), + (False, selectinload), + (True, selectinload), + ) + def test_related_eagerload_against_text(self, add_columns, loader_option): + # new in 1.4. textual selects have columns so subqueryloaders + # and selectinloaders can join onto them. we add columns + # automatiacally to TextClause as well, however subqueryloader + # is not working at the moment due to execution model refactor, + # it creates a subquery w/ adapter before those columns are + # available. this is a super edge case and as we want to rewrite + # the loaders to use select(), maybe we can get it then. + User = self.classes.User + + text_clause = text("select * from users") + if add_columns: + text_clause = text_clause.columns(User.id, User.name) + + s = create_session() + q = ( + s.query(User) + .from_statement(text_clause) + .options(loader_option(User.addresses)) ) - # we can't ORDER BY in this test because SQL server won't let the - # ORDER BY work inside the subqueryload; the test needs to use - # subqueryload (not selectinload) to confirm the feature def go(): eq_(set(q.all()), set(self.static.user_address_result)) @@ -5897,41 +5976,52 @@ class SessionBindTest(QueryTest): yield for call_ in get_bind.mock_calls: if expect_mapped_bind: - is_(call_[1][0], inspect(self.classes.User)) + eq_( + call_, + mock.call( + clause=mock.ANY, mapper=inspect(self.classes.User) + ), + ) else: - is_(call_[1][0], None) - is_not_(call_[2]["clause"], None) + eq_(call_, mock.call(clause=mock.ANY)) def test_single_entity_q(self): User = self.classes.User session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User).all() def test_aliased_entity_q(self): User = self.classes.User u = aliased(User) session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): session.query(u).all() def test_sql_expr_entity_q(self): User = self.classes.User session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User.id).all() def test_sql_expr_subquery_from_entity(self): User = self.classes.User session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): subq = session.query(User.id).subquery() session.query(subq).all() + def test_sql_expr_exists_from_entity(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=True): + subq = session.query(User.id).exists() + session.query(subq).all() + def test_sql_expr_cte_from_entity(self): User = self.classes.User session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): cte = session.query(User.id).cte() subq = session.query(cte).subquery() session.query(subq).all() @@ -5939,7 +6029,7 @@ class SessionBindTest(QueryTest): def test_sql_expr_bundle_cte_from_entity(self): User = self.classes.User session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): cte = session.query(User.id, User.name).cte() subq = session.query(cte).subquery() bundle = Bundle(subq.c.id, subq.c.name) @@ -5948,15 +6038,58 @@ class SessionBindTest(QueryTest): def test_count(self): User = self.classes.User session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): session.query(User).count() + def test_single_col(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=True): + session.query(User.name).all() + + def test_single_col_from_subq(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=True): + subq = session.query(User.id, User.name).subquery() + session.query(subq.c.name).all() + def test_aggregate_fn(self): User = self.classes.User session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): session.query(func.max(User.name)).all() + def test_case(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=True): + session.query(case([(User.name == "x", "C")], else_="W")).all() + + def test_cast(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=True): + session.query(cast(User.name, String())).all() + + def test_type_coerce(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=True): + session.query(type_coerce(User.name, String())).all() + + def test_binary_op(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=True): + session.query(User.name + "x").all() + + def test_boolean_op(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=True): + session.query(User.name == "x").all() + def test_bulk_update_no_sync(self): User = self.classes.User session = Session() @@ -5998,7 +6131,7 @@ class SessionBindTest(QueryTest): column_property(func.coalesce(self.tables.users.c.name, None)), ) session = Session() - with self._assert_bind_args(session): + with self._assert_bind_args(session, expect_mapped_bind=True): session.query(func.max(User.score)).scalar() def test_plain_table(self): @@ -6008,9 +6141,10 @@ class SessionBindTest(QueryTest): with self._assert_bind_args(session, expect_mapped_bind=False): session.query(inspect(User).local_table).all() - def test_plain_table_from_self(self): + def _test_plain_table_from_self(self): User = self.classes.User + # TODO: this test is dumb session = Session() with self._assert_bind_args(session, expect_mapped_bind=False): session.query(inspect(User).local_table).from_self().all() diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index 39e4f89ab8..3d43bb4414 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -2822,7 +2822,7 @@ class SubqueryloadDistinctTest( # Director.photos expect_distinct = director_strategy_level in (True, None) - s = create_session() + s = create_session(testing.db) q = s.query(Movie).options( subqueryload(Movie.director).subqueryload(Director.photos) @@ -2847,7 +2847,9 @@ class SubqueryloadDistinctTest( ) ctx2 = q2._compile_context() - result = s.execute(q2) + stmt = q2.statement + + result = s.connection().execute(stmt) rows = result.fetchall() if expect_distinct: @@ -2876,7 +2878,11 @@ class SubqueryloadDistinctTest( "ON director_1.id = director_photo.director_id" % (" DISTINCT" if expect_distinct else ""), ) - result = s.execute(q3) + + stmt = q3.statement + + result = s.connection().execute(stmt) + rows = result.fetchall() if expect_distinct: eq_( @@ -2911,7 +2917,7 @@ class SubqueryloadDistinctTest( Movie = self.classes.Movie Credit = self.classes.Credit - s = create_session() + s = create_session(testing.db) q = s.query(Credit).options( subqueryload(Credit.movie).subqueryload(Movie.director) @@ -2927,7 +2933,9 @@ class SubqueryloadDistinctTest( ("subqueryload_data", (inspect(Movie), Movie.director.property)) ]["query"] - result = s.execute(q3) + stmt = q3.statement + + result = s.connection().execute(stmt) eq_(result.fetchall(), [(1, "Woody Allen", 1), (1, "Woody Allen", 1)]) diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 6be6898beb..4b4411080a 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -366,7 +366,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): u = User(name="u1") sess.add(u) sess.flush() - c1 = sess.connection(User) + c1 = sess.connection(bind_arguments={"mapper": User}) dbapi_conn = c1.connection assert dbapi_conn.is_valid @@ -383,7 +383,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): assert not dbapi_conn.is_valid eq_(sess.query(User).all(), []) - c2 = sess.connection(User) + c2 = sess.connection(bind_arguments={"mapper": User}) assert not c2.invalidated assert c2.connection.is_valid diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 2a7d09fad4..86221a08f3 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -740,7 +740,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): # outwit the database transaction isolation and SQLA's # expiration at the same time by using different Session on # same transaction - s2 = Session(bind=s.connection(Node)) + s2 = Session(bind=s.connection(mapper=Node)) s2.query(Node).filter(Node.id == n2.id).update({"version_id": 3}) s2.commit() @@ -762,7 +762,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): ), patch.object( config.db.dialect, "supports_sane_multi_rowcount", False ): - s2 = Session(bind=s.connection(Node)) + s2 = Session(bind=s.connection(mapper=Node)) s2.query(Node).filter(Node.id == n2.id).update({"version_id": 3}) s2.commit() @@ -783,7 +783,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): # outwit the database transaction isolation and SQLA's # expiration at the same time by using different Session on # same transaction - s2 = Session(bind=s.connection(Node)) + s2 = Session(bind=s.connection(mapper=Node)) s2.query(Node).filter(Node.id == n1.id).update({"version_id": 3}) s2.commit() diff --git a/test/perf/orm2010.py b/test/perf/orm2010.py index d9efc50a32..5682197705 100644 --- a/test/perf/orm2010.py +++ b/test/perf/orm2010.py @@ -68,7 +68,7 @@ if os.path.exists("orm2010.db"): os.remove("orm2010.db") # use a file based database so that cursor.execute() has some # palpable overhead. -engine = create_engine("sqlite:///orm2010.db") +engine = create_engine("sqlite:///orm2010.db", query_cache_size=100) Base.metadata.create_all(engine) @@ -178,7 +178,7 @@ def run_with_profile(runsnake=False, dump=False): if dump: # stats.sort_stats("nfl") - stats.sort_stats("file", "name") + stats.sort_stats("cumtime", "calls") stats.print_stats() # stats.print_callers() @@ -186,7 +186,7 @@ def run_with_profile(runsnake=False, dump=False): os.system("runsnake %s" % filename) -def run_with_time(): +def run_with_time(factor): import time now = time.time() @@ -222,7 +222,13 @@ if __name__ == "__main__": action="store_true", help="invoke runsnakerun (implies --profile)", ) - + parser.add_argument( + "--factor", + type=int, + default=10, + help="scale factor, a multiple of how many records to work with. " + "defaults to 10", + ) args = parser.parse_args() args.profile = args.profile or args.dump or args.runsnake @@ -230,4 +236,4 @@ if __name__ == "__main__": if args.profile: run_with_profile(runsnake=args.runsnake, dump=args.dump) else: - run_with_time() + run_with_time(args.factor) diff --git a/test/profiles.txt b/test/profiles.txt index 842caf4cd6..b52c99e980 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -165,66 +165,66 @@ test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_ # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation 2.7_sqlite_pysqlite_dbapiunicode_cextensions 45805 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 56605 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation 3.8_sqlite_pysqlite_dbapiunicode_cextensions 49605 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 60905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation 2.7_sqlite_pysqlite_dbapiunicode_cextensions 47405 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 58405 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation 3.8_sqlite_pysqlite_dbapiunicode_cextensions 51005 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 63005 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation 2.7_sqlite_pysqlite_dbapiunicode_cextensions 44305 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 55105 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation 3.8_sqlite_pysqlite_dbapiunicode_cextensions 48105 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 59405 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation 2.7_sqlite_pysqlite_dbapiunicode_cextensions 46305 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 57305 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation 3.8_sqlite_pysqlite_dbapiunicode_cextensions 49905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 61905 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 43405 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 51705 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 46605 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 55405 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 45305 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 53805 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 48205 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 57705 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 42605 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 50905 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 45805 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 54605 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 44505 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 53005 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 47405 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 56905 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle 2.7_sqlite_pysqlite_dbapiunicode_cextensions 42905 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 46405 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle 3.8_sqlite_pysqlite_dbapiunicode_cextensions 45505 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 49505 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle 2.7_sqlite_pysqlite_dbapiunicode_cextensions 44905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 48605 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle 3.8_sqlite_pysqlite_dbapiunicode_cextensions 47205 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 51905 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 43405 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 51705 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 46605 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 55405 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 45305 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 53805 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 48205 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 57705 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 42605 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 50905 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 45805 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 54605 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 44505 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 53005 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 47405 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 56905 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 27805 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 30005 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 30605 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 32705 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 29705 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 32105 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 32205 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 35005 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 27005 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 29205 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 29805 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 31905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_cextensions 28905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 31305 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_cextensions 31405 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 34205 # TEST: test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set @@ -263,66 +263,66 @@ test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branchi # TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching 2.7_sqlite_pysqlite_dbapiunicode_cextensions 404 -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 404 -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching 3.8_sqlite_pysqlite_dbapiunicode_cextensions 410 -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 410 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching 2.7_sqlite_pysqlite_dbapiunicode_cextensions 409 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 409 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching 3.8_sqlite_pysqlite_dbapiunicode_cextensions 415 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 415 # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline 2.7_sqlite_pysqlite_dbapiunicode_cextensions 15169 -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 26174 -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline 3.8_sqlite_pysqlite_dbapiunicode_cextensions 15206 -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 27211 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline 2.7_sqlite_pysqlite_dbapiunicode_cextensions 15186 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 26199 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline 3.8_sqlite_pysqlite_dbapiunicode_cextensions 15220 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 27238 # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 21308 -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 26313 -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols 3.8_sqlite_pysqlite_dbapiunicode_cextensions 21352 -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 27357 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 21337 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 26350 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols 3.8_sqlite_pysqlite_dbapiunicode_cextensions 21378 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 27396 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased 2.7_sqlite_pysqlite_dbapiunicode_cextensions 9603 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 9603 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 9753 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased 3.8_sqlite_pysqlite_dbapiunicode_cextensions 10054 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 10054 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 10204 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain 2.7_sqlite_pysqlite_dbapiunicode_cextensions 3803 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 3803 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 3953 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain 3.8_sqlite_pysqlite_dbapiunicode_cextensions 3804 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 3804 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 3954 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d 2.7_sqlite_pysqlite_dbapiunicode_cextensions 93288 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 93288 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d 3.8_sqlite_pysqlite_dbapiunicode_cextensions 100904 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 100904 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d 2.7_sqlite_pysqlite_dbapiunicode_cextensions 93388 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 93738 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d 3.8_sqlite_pysqlite_dbapiunicode_cextensions 101204 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 101354 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased 2.7_sqlite_pysqlite_dbapiunicode_cextensions 91388 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 91388 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased 3.8_sqlite_pysqlite_dbapiunicode_cextensions 99319 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 99319 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased 2.7_sqlite_pysqlite_dbapiunicode_cextensions 91488 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 91838 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased 3.8_sqlite_pysqlite_dbapiunicode_cextensions 99619 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 99769 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query 2.7_sqlite_pysqlite_dbapiunicode_cextensions 433700 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 433690 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query 3.8_sqlite_pysqlite_dbapiunicode_cextensions 464467 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 464467 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query 2.7_sqlite_pysqlite_dbapiunicode_cextensions 435824 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 437676 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query 3.8_sqlite_pysqlite_dbapiunicode_cextensions 466586 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 468428 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results 2.7_sqlite_pysqlite_dbapiunicode_cextensions 448792 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 463192 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results 3.8_sqlite_pysqlite_dbapiunicode_cextensions 453801 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 472001 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results 2.7_sqlite_pysqlite_dbapiunicode_cextensions 438787 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 455887 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results 3.8_sqlite_pysqlite_dbapiunicode_cextensions 445894 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 463494 # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity @@ -333,24 +333,24 @@ test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_ # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity -test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity 2.7_sqlite_pysqlite_dbapiunicode_cextensions 93373 -test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 96080 -test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity 3.8_sqlite_pysqlite_dbapiunicode_cextensions 94821 -test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 98576 +test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity 2.7_sqlite_pysqlite_dbapiunicode_cextensions 89497 +test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 92264 +test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity 3.8_sqlite_pysqlite_dbapiunicode_cextensions 91083 +test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 94852 # TEST: test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks -test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks 2.7_sqlite_pysqlite_dbapiunicode_cextensions 19452 -test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 19728 -test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks 3.8_sqlite_pysqlite_dbapiunicode_cextensions 20298 -test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 20700 +test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks 2.7_sqlite_pysqlite_dbapiunicode_cextensions 19498 +test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 19970 +test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks 3.8_sqlite_pysqlite_dbapiunicode_cextensions 20344 +test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 20924 # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_load -test.aaa_profiling.test_orm.MergeTest.test_merge_load 2.7_sqlite_pysqlite_dbapiunicode_cextensions 1134 -test.aaa_profiling.test_orm.MergeTest.test_merge_load 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 1157 -test.aaa_profiling.test_orm.MergeTest.test_merge_load 3.8_sqlite_pysqlite_dbapiunicode_cextensions 1168 -test.aaa_profiling.test_orm.MergeTest.test_merge_load 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 1199 +test.aaa_profiling.test_orm.MergeTest.test_merge_load 2.7_sqlite_pysqlite_dbapiunicode_cextensions 1141 +test.aaa_profiling.test_orm.MergeTest.test_merge_load 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 1175 +test.aaa_profiling.test_orm.MergeTest.test_merge_load 3.8_sqlite_pysqlite_dbapiunicode_cextensions 1177 +test.aaa_profiling.test_orm.MergeTest.test_merge_load 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 1221 # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_no_load @@ -361,24 +361,24 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 3.8_sqlite_pysqlite_dba # TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 5437 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 6157 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.8_sqlite_pysqlite_dbapiunicode_cextensions 5795 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 6505 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 5559 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 6299 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.8_sqlite_pysqlite_dbapiunicode_cextensions 5887 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 6667 # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results 2.7_sqlite_pysqlite_dbapiunicode_cextensions 184177 -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 200783 -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results 3.8_sqlite_pysqlite_dbapiunicode_cextensions 189638 -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 207344 +test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results 2.7_sqlite_pysqlite_dbapiunicode_cextensions 182806 +test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 199629 +test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results 3.8_sqlite_pysqlite_dbapiunicode_cextensions 187973 +test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 206403 # TEST: test.aaa_profiling.test_orm.SessionTest.test_expire_lots -test.aaa_profiling.test_orm.SessionTest.test_expire_lots 2.7_sqlite_pysqlite_dbapiunicode_cextensions 1150 -test.aaa_profiling.test_orm.SessionTest.test_expire_lots 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 1166 -test.aaa_profiling.test_orm.SessionTest.test_expire_lots 3.8_sqlite_pysqlite_dbapiunicode_cextensions 1263 -test.aaa_profiling.test_orm.SessionTest.test_expire_lots 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 1259 +test.aaa_profiling.test_orm.SessionTest.test_expire_lots 2.7_sqlite_pysqlite_dbapiunicode_cextensions 1155 +test.aaa_profiling.test_orm.SessionTest.test_expire_lots 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 1152 +test.aaa_profiling.test_orm.SessionTest.test_expire_lots 3.8_sqlite_pysqlite_dbapiunicode_cextensions 1246 +test.aaa_profiling.test_orm.SessionTest.test_expire_lots 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 1266 # TEST: test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 247332d8ca..d3d21cb0e3 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -681,7 +681,7 @@ class CacheKeyFixture(object): continue eq_(a_key.key, b_key.key) - eq_(hash(a_key), hash(b_key)) + eq_(hash(a_key.key), hash(b_key.key)) for a_param, b_param in zip( a_key.bindparams, b_key.bindparams diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index c580e972d9..efe4d08c53 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -84,6 +84,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import eq_ignore_whitespace from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import mock from sqlalchemy.util import u table1 = table( @@ -5198,9 +5199,16 @@ class ResultMapTest(fixtures.TestBase): wrapped_again = select([c for c in wrapped.c]) - compiled = wrapped_again.compile( - compile_kwargs={"select_wraps_for": stmt} - ) + dialect = default.DefaultDialect() + + with mock.patch.object( + dialect.statement_compiler, + "translate_select_structure", + lambda self, to_translate, **kw: wrapped_again + if to_translate is stmt + else to_translate, + ): + compiled = stmt.compile(dialect=dialect) proxied = [obj[0] for (k, n, obj, type_) in compiled._result_columns] for orig_obj, proxied_obj in zip(orig, proxied): @@ -5245,9 +5253,17 @@ class ResultMapTest(fixtures.TestBase): # so the compiler logic that matches up the "wrapper" to the # "select_wraps_for" can't use inner_columns to match because # these collections are not the same - compiled = wrapped_again.compile( - compile_kwargs={"select_wraps_for": stmt} - ) + + dialect = default.DefaultDialect() + + with mock.patch.object( + dialect.statement_compiler, + "translate_select_structure", + lambda self, to_translate, **kw: wrapped_again + if to_translate is stmt + else to_translate, + ): + compiled = stmt.compile(dialect=dialect) proxied = [obj[0] for (k, n, obj, type_) in compiled._result_columns] for orig_obj, proxied_obj in zip(orig, proxied): diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index 578743750c..2022167234 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -1120,10 +1120,10 @@ class CursorResultTest(fixtures.TablesTest): users = self.tables.users with testing.expect_deprecated( - "Retreiving row values using Column objects " - "with only matching names", - "Using non-integer/slice indices on Row is " - "deprecated and will be removed in version 2.0", + # "Retreiving row values using Column objects " + # "with only matching names", + # "Using non-integer/slice indices on Row is " + # "deprecated and will be removed in version 2.0", ): # this will create column() objects inside # the select(), these need to match on name anyway @@ -1137,14 +1137,14 @@ class CursorResultTest(fixtures.TablesTest): r._keymap.pop(users.c.user_id) # reset lookup with testing.expect_deprecated( - "Retreiving row values using Column objects " - "with only matching names" + # "Retreiving row values using Column objects " + # "with only matching names" ): eq_(r._mapping[users.c.user_id], 2) with testing.expect_deprecated( - "Retreiving row values using Column objects " - "with only matching names" + # "Retreiving row values using Column objects " + # "with only matching names" ): eq_(r._mapping[users.c.user_name], "jack") diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 6c83697dcc..0eff94635c 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1856,7 +1856,34 @@ class KeyTargetingTest(fixtures.TablesTest): is_( existing_metadata._keymap[k], adapted_metadata._keymap[other_k] ) - return stmt1, existing_metadata, stmt2, adapted_metadata + + @testing.combinations( + _adapt_result_columns_fixture_one, + _adapt_result_columns_fixture_two, + _adapt_result_columns_fixture_three, + _adapt_result_columns_fixture_four, + argnames="stmt_fn", + ) + def test_adapt_result_columns_from_cache(self, connection, stmt_fn): + stmt1 = stmt_fn(self) + stmt2 = stmt_fn(self) + + cache = {} + result = connection._execute_20( + stmt1, + execution_options={"compiled_cache": cache, "future_result": True}, + ) + result.close() + assert cache + + result = connection._execute_20( + stmt2, + execution_options={"compiled_cache": cache, "future_result": True}, + ) + + row = result.first() + for col in stmt2.selected_columns: + assert col in row._mapping class PositionalTextTest(fixtures.TablesTest):