+-*- coding: utf-8; fill-column: 68 -*-
+
=======
CHANGES
=======
+user_defined_state
+==================
+
+ - The "__init__" trigger/decorator added by mapper now attempts
+ to exactly mirror the argument signature of the original
+ __init__. The pass-through for '_sa_session' is no longer
+   implicit - you must allow for this keyword argument in your
+ constructor.
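+
+   A minimal sketch of a constructor that allows for the
+   keyword (class and attribute names are illustrative):
+
+     class MyClass(object):
+         def __init__(self, name, _sa_session=None):
+             self.name = name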
+
+ - ClassState is renamed to ClassManager.
+
+ - Classes may supply their own InstrumentationManager by
+ providing a __sa_instrumentation_manager__ property.
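+
+   A sketch, assuming InstrumentationManager is importable
+   from this branch's attributes module:
+
+     from sqlalchemy.orm.attributes import InstrumentationManager
+
+     class MyManager(InstrumentationManager):
+         pass
+
+     class MyClass(object):
+         __sa_instrumentation_manager__ = MyManager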
+
+ - Custom instrumentation may use any mechanism to associate a
+ ClassManager with a class and an InstanceState with an
+ instance. Attributes on those objects are still the default
+ association mechanism used by SQLAlchemy's native
+ instrumentation.
+
+ - Moved entity_name, _sa_session_id, and _instance_key from the
+ instance object to the instance state. These values are still
+ available in the old way, which is now deprecated, using
+ descriptors attached to the class. A deprecation warning will
+ be issued when accessed.
+
+ - attribute savepoint/rollback capability has been added. For
+ starters, this takes effect within the flush() call, so that
+ attribute changes which occur within flush() are rolled back
+ when the flush fails. Since it's primarily new primary key
+ values that get assigned within flush(), expiring those
+ attributes is not an option. The next place we might use
+ savepoints is within SAVEPOINT transactions, since rolling
+ back to a savepoint is a transaction-contained operation.
+
+ - The _prepare_instrumentation alias for prepare_instrumentation
+ has been removed.
+
+ - sqlalchemy.exceptions has been renamed to sqlalchemy.exc. The
+ module may be imported under either name.
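+
+   For example, both imports work during the transition:
+
+     from sqlalchemy import exc
+     import sqlalchemy.exceptions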
+
+ - ORM-related exceptions are now defined in sqlalchemy.orm.exc.
+ ConcurrentModificationError, FlushError, and
+ UnmappedColumnError compatibility aliases are installed in
+ sqlalchemy.exc during the import of sqlalchemy.orm.
+
+ - sqlalchemy.logging has been renamed to sqlalchemy.log.
+
+ - The transitional sqlalchemy.log.SADeprecationWarning alias for
+ the warning's definition in sqlalchemy.exc has been removed.
+
+ - exc.AssertionError has been removed and usage replaced with
+ Python's built-in AssertionError.
+
+ - The behavior of MapperExtensions attached to multiple,
+ entity_name= primary mappers for a single class has been
+ altered. The first mapper() defined for a class is the only
+ mapper eligible for the MapperExtension 'instrument_class',
+ 'init_instance' and 'init_failed' events. This is backwards
+   incompatible; previously the extensions of the last mapper
+   defined would receive these events.
+
+
0.4.6
=====
- orm
- - A fix to the recent relation() refactoring which fixes
+ - Fix to the recent relation() refactoring which fixes
exotic viewonly relations which join between local and
remote table multiple times, with a common column shared
between the joins.
- Also re-established viewonly relation() configurations
that join across multiple tables.
- - contains_eager(), the hot function of the week, suppresses
- the eager loader's own generation of the LEFT OUTER JOIN,
- so that it is reasonable to use any Query, not just those
- which use from_statement().
-
- - Added an experimental relation() flag to help with
+ - Added experimental relation() flag to help with
primaryjoins across functions, etc.,
_local_remote_pairs=[tuples]. This complements a complex
     primaryjoin condition, allowing you to provide the
     individual column pairs which comprise the relation's
     local and remote sides.
[ticket:1027]
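
     A sketch of the flag (tables and join condition are
     illustrative):

       relation(Address,
           primaryjoin=func.lower(users.c.email)==addresses.c.email,
           _local_remote_pairs=[(users.c.email, addresses.c.email)])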
- - Removed an ancient assertion that mapped selectables
- require "alias names" - the mapper creates its own alias
- now if none is present. Though in this case you need to
-   use the class, not the mapped selectable, as the source of
-   column attributes - so a warning is still issued.
+ - Removed ancient assertion that mapped selectables require
+ "alias names" - the mapper creates its own alias now if
+ none is present. Though in this case you need to use the
+ class, not the mapped selectable, as the source of column
+ attributes - so a warning is still issued.
+
+ - fixes to the "exists" function involving inheritance (any(), has(),
+ ~contains()); the full target join will be rendered into the
+ EXISTS clause for relations that link to subclasses.
+
+ - restored usage of append_result() extension method for primary
+ query rows, when the extension is present and only a single-
+ entity result is being returned.
+
-
- - Fixes to the "exists" function involving inheritance
- (any(), has(), ~contains()); the full target join will be
- rendered into the EXISTS clause for relations that link to
- subclasses.
-
- - Restored usage of append_result() extension method for
- primary query rows, when the extension is present and only
- a single- entity result is being returned.
-
- - Fixed Class.collection==None for m2m relationships
- [ticket:4213]
-
- - Refined mapper._save_obj() which was unnecessarily calling
+
+ - refined mapper._save_obj() which was unnecessarily calling
__ne__() on scalar values during flush [ticket:1015]
-
- - Added a feature to eager loading whereby subqueries set as
- column_property() with explicit label names (which is not
- necessary, btw) will have the label anonymized when the
- instance is part of the eager join, to prevent conflicts
- with a subquery or column of the same name on the parent
- object. [ticket:1019]
-
- - Same as [ticket:1019] but repaired the non-labeled use
- case [ticket:1022]
-
- - Adjusted class-member inspection during attribute and
- collection instrumentation that could be problematic when
- integrating with other frameworks.
-
- - Fixed duplicate append event emission on repeated
- instrumented set.add() operations.
-
+
+ - added a feature to eager loading whereby subqueries set
+ as column_property() with explicit label names (which is not
+ necessary, btw) will have the label anonymized when
+ the instance is part of the eager join, to prevent
+ conflicts with a subquery or column of the same name
+ on the parent object. [ticket:1019]
+
- set-based collections |=, -=, ^= and &= are stricter about
their operands and only operate on sets, frozensets or
subclasses of the collection type. Previously, they would
a simple way to place dictionary behavior on top of
a dynamic_loader.
+- declarative extension
+ - Joined table inheritance mappers use a slightly relaxed
+ function to create the "inherit condition" to the parent
+ table, so that other foreign keys to not-yet-declared
+ Table objects don't trigger an error.
+
+ - fixed reentrant mapper compile hang when
+ a declared attribute is used within ForeignKey,
+   i.e. ForeignKey(MyOtherClass.someattribute)
+
- sql
- Added COLLATE support via the .collate(<collation>)
expression operator and collate(<expr>, <collation>) sql
- Fixed bug with union() when applied to non-Table connected
select statements
- - Improved behavior of text() expressions when used as FROM
- clauses, such as select().select_from(text("sometext"))
+ - improved behavior of text() expressions when used as
+ FROM clauses, such as select().select_from(text("sometext"))
[ticket:1014]
- - Column.copy() respects the value of "autoincrement", fixes
- usage with Migrate [ticket:1021]
-
+ - Column.copy() respects the value of "autoincrement",
+ fixes usage with Migrate [ticket:1021]
+
- engines
- Pool listeners can now be provided as a dictionary of
callables or a (possibly partial) duck-type of
PoolListener, your choice.
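
     For example, a sketch using a dictionary of callables
     ('checkout' is one of the PoolListener hook names):

       from sqlalchemy import create_engine

       def checkout(dbapi_con, con_record, con_proxy):
           pass

       engine = create_engine('sqlite://',
                              listeners=[{'checkout': checkout}])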
+
+ - added "rollback_returned" option to Pool which will
+ disable the rollback() issued when connections are
+ returned. This flag is only safe to use with a database
+ which does not support transactions (i.e. MySQL/MyISAM).
- - Added "reset_on_return" option to Pool which will disable
- the database state cleanup step (e.g. issuing a
- rollback()) when connections are returned to the pool.
-
--extensions
+-ext
- set-based association proxies |=, -=, ^= and &= are
stricter about their operands and only operate on sets,
frozensets or other association proxies. Previously, they
would accept any duck-typed set.
-- declarative extension
- - Joined table inheritance mappers use a slightly relaxed
- function to create the "inherit condition" to the parent
- table, so that other foreign keys to not-yet-declared
- Table objects don't trigger an error.
-
- - Fixed re-entrant mapper compile hang when a declared
- attribute is used within ForeignKey,
- i.e. ForeignKey(MyOtherClass.someattribute)
-
- mssql
- Added "odbc_autotranslate" parameter to engine / dburi
parameters. Any given string will be passed through to the
This should obviate the need of adding a myriad of ODBC
options in the future.
+
- firebird
- Handle the "SUBSTRING(:string FROM :start FOR :length)"
builtin.
-
0.4.5
=====
- orm
- Added comparable_property(), adds query Comparator
behavior to regular, unmanaged Python properties
- - The functionality of query.with_polymorphic() has been
- added to mapper() as a configuration option.
+ - the functionality of query.with_polymorphic() has
+ been added to mapper() as a configuration option.
It's set via several forms:
-
with_polymorphic='*'
with_polymorphic=[mappers]
with_polymorphic=('*', selectable)
with_polymorphic=([mappers], selectable)
-
- This controls the default polymorphic loading strategy for
- inherited mappers. When a selectable is not given, outer
- joins are created for all joined-table inheriting mappers
- requested. Note that the auto-create of joins is not
- compatible with concrete table inheritance.
-
- The existing select_table flag on mapper() is now
- deprecated and is synonymous with:
-
- with_polymorphic('*', select_table).
-
- Note that the underlying "guts" of select_table have been
- completely removed and replaced with the newer, more
- flexible approach.
-
- The new approach also automatically allows eager loads to
- work for subclasses, if they are present, for example
-
+
+ This controls the default polymorphic loading strategy
+ for inherited mappers. When a selectable is not given,
+ outer joins are created for all joined-table inheriting
+ mappers requested. Note that the auto-create of joins
+ is not compatible with concrete table inheritance.
+
+ The existing select_table flag on mapper() is now
+ deprecated and is synonymous with
+ with_polymorphic('*', select_table). Note that the
+ underlying "guts" of select_table have been
+ completely removed and replaced with the newer,
+ more flexible approach.
+
+ The new approach also automatically allows eager loads
+ to work for subclasses, if they are present, for
+ example
sess.query(Company).options(
eagerload_all(
[Company.employees.of_type(Engineer), 'machines']
))
-
to load Company objects, their employees, and the
'machines' collection of employees who happen to be
Engineers. A "with_polymorphic" Query option should be
introduced soon as well which would allow per-Query
control of with_polymorphic() on relations.
-
- - Added two "experimental" features to Query, "experimental"
- in that their specific name/behavior is not carved in
- stone just yet: _values() and _from_self(). We'd like
- feedback on these.
-
- - _values(*columns) is given a list of column expressions,
- and returns a new Query that only returns those
- columns. When evaluated, the return value is a list of
- tuples just like when using add_column() or
- add_entity(), the only difference is that "entity zero",
- i.e. the mapped class, is not included in the
- results. This means it finally makes sense to use
- group_by() and having() on Query, which have been
- sitting around uselessly until now.
-
+
+ - added two "experimental" features to Query,
+ "experimental" in that their specific name/behavior
+ is not carved in stone just yet: _values() and
+ _from_self(). We'd like feedback on these.
+
+ - _values(*columns) is given a list of column
+ expressions, and returns a new Query that only
+ returns those columns. When evaluated, the return
+ value is a list of tuples just like when using
+ add_column() or add_entity(), the only difference is
+ that "entity zero", i.e. the mapped class, is not
+ included in the results. This means it finally makes
+ sense to use group_by() and having() on Query, which
+ have been sitting around uselessly until now.
+
       A future change to this method may remove its
       ability to join, filter and allow other options not
       related to a "resultset", so the feedback
we're looking for is how people want to use
- _values()...i.e. at the very end, or do people prefer to
- continue generating after it's called.
-
- - _from_self() compiles the SELECT statement for the Query
- (minus any eager loaders), and returns a new Query that
- selects from that SELECT. So basically you can query
- from a Query without needing to extract the SELECT
- statement manually. This gives meaning to operations
- like query[3:5]._from_self().filter(some
- criterion). There's not much controversial here except
- that you can quickly create highly nested queries that
- are less efficient, and we want feedback on the naming
- choice.
-
- - query.order_by() and query.group_by() will accept multiple
- arguments using *args (like select() already does).
-
+ _values()...i.e. at the very end, or do people prefer
+ to continue generating after it's called.
+
+ - _from_self() compiles the SELECT statement for the
+ Query (minus any eager loaders), and returns a new
+ Query that selects from that SELECT. So basically you
+ can query from a Query without needing to extract the
+ SELECT statement manually. This gives meaning to
+ operations like query[3:5]._from_self().filter(some
+ criterion). There's not much controversial here
+ except that you can quickly create highly nested
+ queries that are less efficient, and we want feedback
+ on the naming choice.
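+
+        A sketch of both, using illustrative entities:
+
+          session.query(User)._values(User.name, User.fullname)
+
+          session.query(User)[3:5]._from_self().filter(
+              User.name=='ed')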
+
+ - query.order_by() and query.group_by() will accept
+ multiple arguments using *args (like select()
+ already does).
+
- Added some convenience descriptors to Query:
query.statement returns the full SELECT construct,
query.whereclause returns just the WHERE part of the
- Delete cascade with delete-orphan will delete orphans
      whether or not they remain attached to their also-deleted
      parent.
-
- - delete-orphan casacde is properly detected on
- relations that are present on superclasses when using
- inheritance.
+
+    - delete-orphan cascade is properly detected on relations
+ that are present on superclasses when using inheritance.
- Fixed order_by calculation in Query to properly alias
mapper-config'ed order_by when using select_from()
iterative to support deep object graphs.
- sql
- - Schema-qualified tables now will place the schemaname
+ - schema-qualified tables now will place the schemaname
ahead of the tablename in all column expressions as well
as when generating column labels. This prevents cross-
schema name collisions in all cases [ticket:999]
-
- - Can now allow selects which correlate all FROM clauses and
- have no FROM themselves. These are typically used in a
- scalar context, i.e. SELECT x, (SELECT x WHERE y) FROM
- table. Requires explicit correlate() call.
-
+
+ - can now allow selects which correlate all FROM clauses
+ and have no FROM themselves. These are typically
+ used in a scalar context, i.e. SELECT x, (SELECT x WHERE y)
+ FROM table. Requires explicit correlate() call.
+
- 'name' is no longer a required constructor argument for
Column(). It (and .key) may now be deferred until the
column is added to a Table.
SA will force explicit usage of either text() or
literal().
+- oracle
+ - The "owner" keyword on Table is now deprecated, and is
+ exactly synonymous with the "schema" keyword. Tables can
+ now be reflected with alternate "owner" attributes,
+    explicitly stated on the Table object or not, using
+ "schema".
+
+ - All of the "magic" searching for synonyms, DBLINKs etc.
+ during table reflection are disabled by default unless you
+ specify "oracle_resolve_synonyms=True" on the Table
+ object. Resolving synonyms necessarily leads to some
+ messy guessing which we'd rather leave off by default.
+ When the flag is set, tables and related tables will be
+ resolved against synonyms in all cases, meaning if a
+ synonym exists for a particular table, reflection will use
+ it when reflecting related tables. This is stickier
+ behavior than before which is why it's off by default.
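+
+     For example, assuming a bound MetaData:
+
+       t = Table('sometable', meta, autoload=True,
+                 oracle_resolve_synonyms=True)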
+
- declarative extension
- The "synonym" function is now directly usable with
"declarative". Pass in the decorated property using the
- inheritance in declarative can be disabled when sending
"inherits=None" to __mapper_args__.
- - declarative_base() takes optional kwarg "mapper", which
- is any callable/class/method that produces a mapper, such
- as declarative_base(mapper=scopedsession.mapper). This
- property can also be set on individual declarative
+ - declarative_base() takes optional kwarg "mapper", which
+ is any callable/class/method that produces a mapper,
+ such as declarative_base(mapper=scopedsession.mapper).
+ This property can also be set on individual declarative
classes using the "__mapper_cls__" property.
- postgres
behavior than before which is why it's off by default.
- mssql
- - Reflected tables will now automatically load other tables
+ - Reflected tables will now automatically load other tables
which are referenced by Foreign keys in the auto-loaded
- table, [ticket:979].
+ table, [ticket:979].
- - Added executemany check to skip identity fetch,
- [ticket:916].
+ - Added executemany check to skip identity fetch, [ticket:916].
- Added stubs for small date type, [ticket:884]
- - Added a new 'driver' keyword parameter for the pyodbc
- dialect. Will substitute into the ODBC connection string
- if given, defaults to 'SQL Server'.
+ - Added a new 'driver' keyword parameter for the pyodbc dialect.
+ Will substitute into the ODBC connection string if given,
+ defaults to 'SQL Server'.
- Added a new 'max_identifier_length' keyword parameter for
the pyodbc dialect.
--- /dev/null
+Trunk of SQLAlchemy is now on the 0.5 version. This version
+removes many things which were deprecated in 0.4 and therefore
+is not backwards compatible with all 0.4 applications.
+
+A work in progress describing the changes from 0.4 is at:
+
+ http://www.sqlalchemy.org/trac/wiki/05Migration
+
+To continue working with the current development revision of
+version 0.4, switch this working copy to the 0.4 maintenance branch:
+
+ svn switch http://svn.sqlalchemy.org/sqlalchemy/branches/rel_0_4
+
+
--- /dev/null
+# Exercises the new attribute savepoint/rollback capability via the
+# low-level instrumentation API in sqlalchemy.orm.attributes.
+from sqlalchemy.orm import attributes
+
+class Foo(object):
+    pass
+
+attributes.register_class(Foo)
+attributes.register_attribute(Foo, 'x', uselist=False, useobject=False,
+                              mutable_scalars=True,
+                              copy_function=lambda x: x.copy())
+
+f = Foo()
+
+# mark a savepoint on the instance's attribute state
+f._foostate.set_savepoint()
+print f._foostate.get_history('x')
+
+# assign a new value; the change appears as pending history
+f.x = {'1': 15}
+print f._foostate.get_history('x')
+
+# commit pending attribute changes on the instance
+f._foostate.commit_all()
+print f._foostate.get_history('x')
+
+# mutate the (mutable) scalar value in place
+f.x['2'] = 40
+print f._foostate.get_history('x')
+
+# roll back to the savepoint, discarding changes made since
+f._foostate.rollback()
+print f._foostate.get_history('x')
+
+print f.x
+f.x['2'] = 40
+print f._foostate.get_history('x')
+
* SQLite: [pysqlite](http://initd.org/tracker/pysqlite), [sqlite3](http://docs.python.org/lib/module-sqlite3.html) (included with Python 2.5 or greater)
* MySQL: [MySQLdb](http://sourceforge.net/projects/mysql-python)
* Oracle: [cx_Oracle](http://www.cxtools.net/default.aspx?nav=home)
-* MS-SQL: [pyodbc](http://pyodbc.sourceforge.net/) (recommended), [adodbapi](http://adodbapi.sourceforge.net/) or [pymssql](http://pymssql.sourceforge.net/)
+* MS-SQL, MSAccess: [pyodbc](http://pyodbc.sourceforge.net/) (recommended), [adodbapi](http://adodbapi.sourceforge.net/) or [pymssql](http://pymssql.sourceforge.net/)
* Firebird: [kinterbasdb](http://kinterbasdb.sourceforge.net/)
* Informix: [informixdb](http://informixdb.sourceforge.net/)
+* DB2/Informix IDS: [ibm-db](http://code.google.com/p/ibm-db/)
+* Sybase: TODO
+* MAXDB: TODO
### Checking the Installed SQLAlchemy Version
-This documentation covers SQLAlchemy version 0.4. If you're working on a system that already has SQLAlchemy installed, check the version from your Python prompt like this:
+This documentation covers SQLAlchemy version 0.5. If you're working on a system that already has SQLAlchemy installed, check the version from your Python prompt like this:
{python}
>>> import sqlalchemy
>>> sqlalchemy.__version__ # doctest: +SKIP
- 0.4.0
+ 0.5.0
-## 0.3 to 0.4 Migration {@name=migration}
+## 0.4 to 0.5 Migration {@name=migration}
-From version 0.3 to version 0.4 of SQLAlchemy, some conventions have changed. Most of these conventions are available in the most recent releases of the 0.3 series starting with version 0.3.9, so that you can make a 0.3 application compatible with 0.4 in most cases.
-
-This section will detail only those things that have changed in a backwards-incompatible manner. For a full overview of everything that's new and changed, see [WhatsNewIn04](http://www.sqlalchemy.org/trac/wiki/WhatsNewIn04).
-
-### ORM Package is now sqlalchemy.orm {@name=imports}
-
-All symbols related to the SQLAlchemy Object Relational Mapper, i.e. names like `mapper()`, `relation()`, `backref()`, `create_session()` `synonym()`, `eagerload()`, etc. are now only in the `sqlalchemy.orm` package, and **not** in `sqlalchemy`. So if you were previously importing everything on an asterisk:
-
- {python}
- from sqlalchemy import *
-
-You should now import separately from orm:
-
- {python}
- from sqlalchemy import *
- from sqlalchemy.orm import *
-
-Or more commonly, just pull in the names you'll need:
-
- {python}
- from sqlalchemy import create_engine, MetaData, Table, Column, types
- from sqlalchemy.orm import mapper, relation, backref, create_session
-
-### BoundMetaData is now MetaData {@name=metadata}
-
-The `BoundMetaData` name is removed. Now, you just use `MetaData`. Additionally, the `engine` parameter/attribute is now called `bind`, and `connect()` is deprecated:
-
- {python}
- # plain metadata
- meta = MetaData()
-
- # metadata bound to an engine
- meta = MetaData(engine)
-
- # bind metadata to an engine later
- meta.bind = engine
-
-Additionally, `DynamicMetaData` is now known as `ThreadLocalMetaData`.
-
-### "Magic" Global MetaData removed {@name=global}
-
-There was an old way to specify `Table` objects using an implicit, global `MetaData` object. To do this you'd omit the second positional argument, and specify `Table('tablename', Column(...))`. This no longer exists in 0.4 and the second `MetaData` positional argument is required, i.e. `Table('tablename', meta, Column(...))`.
-
-### Some existing select() methods become generative {@name=generative}
-
-The methods `correlate()`, `order_by()`, and `group_by()` on the `select()` construct now return a **new** select object, and do not change the original one. Additionally, the generative methods `where()`, `column()`, `distinct()`, and several others have been added:
-
- {python}
- s = table.select().order_by(table.c.id).where(table.c.x==7)
- result = engine.execute(s)
-
-### collection_class behavior is changed {@name=collection}
-
-If you've been using the `collection_class` option on `mapper()`, the requirements for instrumented collections have changed. For an overview, see [advdatamapping_relation_collections](rel:advdatamapping_relation_collections).
-
-### All "engine", "bind_to", "connectable" Keyword Arguments Changed to "bind" {@name=bind}
-
-This is for create/drop statements, sessions, SQL constructs, metadatas:
-
- {python}
- myengine = create_engine('sqlite://')
-
- meta = MetaData(myengine)
-
- meta2 = MetaData()
- meta2.bind = myengine
-
- session = create_session(bind=myengine)
-
- statement = select([table], bind=myengine)
-
- meta.create_all(bind=myengine)
-
-### All "type" Keyword Arguments Changed to "type_" {@name=type}
-
-This mostly applies to SQL constructs where you pass a type in:
-
- {python}
- s = select([mytable], mytable.c.x=bindparam(y, type_=DateTime))
-
- func.now(type_=DateTime)
-
-### Mapper Extensions must return EXT_CONTINUE to continue execution to the next mapper
-
-If you extend the mapper, the methods in your mapper extension must return EXT_CONTINUE to continue executing additional mappers.
+Notes on what's changed from 0.4 to 0.5 are available on the SQLAlchemy wiki at [05Migration](http://www.sqlalchemy.org/trac/wiki/05Migration).
{python}
mapper(Address, addresses_table)
mapper(User, users_table, properties={
- 'addresses' : relation(Address,
- primaryjoin=users_table.c.user_id==addresses_table.c.user_id,
- foreign_keys=[addresses_table.c.user_id])
+ 'addresses' : relation(Address, primaryjoin=
+ users_table.c.user_id==addresses_table.c.user_id,
+ foreign_keys=[addresses_table.c.user_id])
})
##### Building Query-Enabled Properties {@name=properties}
There are two other loader strategies available, **dynamic loading** and **no loading**; these are described in [advdatamapping_relation_largecollections](rel:advdatamapping_relation_largecollections).
-##### Combining Eager Loads with Statement/Result Set Queries
+##### Routing Explicit Joins/Statements into Eagerly Loaded Collections {@name=containseager}
-When full statement or result-set loads are used with `Query`, SQLAlchemy does not affect the SQL query itself, and therefore has no way of tacking on its own `LEFT [OUTER] JOIN` conditions that are normally used to eager load relationships. If the query being constructed is created in such a way that it returns rows not just from a parent table (or tables) but also returns rows from child tables, the result-set mapping can be notified as to which additional properties are contained within the result set. This is done using the `contains_eager()` query option, which specifies the name of the relationship to be eagerly loaded.
+When full statement loads are used with `Query`, the user defined SQL is used verbatim and the `Query` does not play any role in generating it. In this scenario, if eager loading is desired, the `Query` should be informed as to what collections should also be loaded from the result set. Similarly, Queries which compile their statement in the usual way may also have user-defined joins built in which are synonymous with what eager loading would normally produce, and it improves performance to utilize those same JOINs for both purposes, instead of allowing the eager load mechanism to generate essentially the same JOIN redundantly. Yet another use case for such a feature is a Query which returns instances with a filtered view of their collections loaded, in which case the default eager load mechanisms need to be bypassed.
+
+The single option `Query` provides to control this is the `contains_eager()` option, which specifies the path of a single relationship to be eagerly loaded. Like all relation-oriented options, it takes a string or Python descriptor as an argument. Below it's used with a `from_statement` load:
{python}
# mapping is the users->addresses mapping
})
# define a query on USERS with an outer join to ADDRESSES
- statement = users_table.outerjoin(addresses_table).select(use_labels=True)
+ statement = users_table.outerjoin(addresses_table).select().apply_labels()
# construct a Query object which expects the "addresses" results
query = session.query(User).options(contains_eager('addresses'))
# get results normally
r = query.from_statement(statement)
+It works just as well with an inline `Query.join()` or `Query.outerjoin()`:
+
+ {python}
+ session.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses)).all()
+
If the "eager" portion of the statement is "aliased", the `alias` keyword argument to `contains_eager()` may be used to indicate it. This is a string alias name or reference to an actual `Alias` object:
{python}
- # use an alias of the addresses table
- adalias = addresses_table.alias('adalias')
+ # use an alias of the Address entity
+ adalias = aliased(Address)
- # define a query on USERS with an outer join to adalias
- statement = users_table.outerjoin(adalias).select(use_labels=True)
-
# construct a Query object which expects the "addresses" results
- query = session.query(User).options(contains_eager('addresses', alias=adalias))
+ query = session.query(User).outerjoin((adalias, User.addresses)).options(contains_eager(User.addresses, alias=adalias))
# get results normally
- {sql}r = query.from_statement(statement).all()
+ {sql}r = query.all()
    SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, email_addresses_1.address_id AS email_addresses_1_address_id,
    email_addresses_1.user_id AS email_addresses_1_user_id, email_addresses_1.email_address AS email_addresses_1_email_address, (...other columns...)
- FROM users LEFT OUTER JOIN email_addresses AS adalias ON users.user_id = adalias.user_id
+ FROM users LEFT OUTER JOIN email_addresses AS email_addresses_1 ON users.user_id = email_addresses_1.user_id
+
+The path given as the argument to `contains_eager()` needs to be a full path from the starting entity. For example if we were loading `User->orders->Order->items->Item`, the string version would look like:
-In the case that the main table itself is also aliased, the `contains_alias()` option can be used:
+ {python}
+ query(User).options(contains_eager('orders', 'items'))
+
+The descriptor version looks like:
+
+ {python}
+ query(User).options(contains_eager(User.orders, Order.items))
+
+A variant on `contains_eager()` is the `contains_alias()` option, which is used in the rare case that the parent object is loaded from an alias within a user-defined SELECT statement:
{python}
# define an aliased UNION called 'ulist'
statement = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist')
# add on an eager load of "addresses"
- statement = statement.outerjoin(addresses).select(use_labels=True)
+ statement = statement.outerjoin(addresses).select().apply_labels()
# create query, indicating "ulist" is an alias for the main table, "addresses" property should
# be eager loaded
- query = create_session().query(User).options(contains_alias('ulist'), contains_eager('addresses'))
+ query = session.query(User).options(contains_alias('ulist'), contains_eager('addresses'))
# results
r = query.from_statement(statement)
-[alpha_api]: javascript:alphaApi()
-[alpha_implementation]: javascript:alphaImplementation()
-
Object Relational Tutorial {@name=datamapping}
============
-In this tutorial we will cover a basic SQLAlchemy object-relational mapping scenario, where we store and retrieve Python objects from a database representation. The database schema will begin with one table, and will later develop into several. The tutorial is in doctest format, meaning each `>>>` line represents something you can type at a Python command prompt, and the following text represents the expected return value. The tutorial has no prerequisites.
+In this tutorial we will cover a basic SQLAlchemy object-relational mapping scenario, where we store and retrieve Python objects from a database representation. The tutorial is in doctest format, meaning each `>>>` line represents something you can type at a Python command prompt, and the following text represents the expected return value.
## Version Check
-A quick check to verify that we are on at least **version 0.4** of SQLAlchemy:
+A quick check to verify that we are on at least **version 0.5** of SQLAlchemy:
{python}
>>> import sqlalchemy
>>> sqlalchemy.__version__ # doctest:+SKIP
- 0.4.0
+ 0.5.0
## Connecting
-For this tutorial we will use an in-memory-only SQLite database. This is an easy way to test things without needing to have an actual database defined anywhere. To connect we use `create_engine()`:
+For this tutorial we will use an in-memory-only SQLite database. To connect we use `create_engine()`:
{python}
>>> from sqlalchemy import create_engine
## Define and Create a Table {@name=tables}
-Next we want to tell SQLAlchemy about our tables. We will start with just a single table called `users`, which will store records for the end-users using our application (lets assume it's a website). We define our tables all within a catalog called `MetaData`, using the `Table` construct, which resembles regular SQL CREATE TABLE syntax:
+Next we want to tell SQLAlchemy about our tables. We will start with just a single table called `users`, which will store records for the end-users using our application (let's assume it's a website). We define our tables within a catalog called `MetaData`, using the `Table` construct, which is used in a manner similar to SQL's CREATE TABLE syntax:
{python}
>>> from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey
>>> metadata = MetaData()
>>> users_table = Table('users', metadata,
... Column('id', Integer, primary_key=True),
- ... Column('name', String(40)),
- ... Column('fullname', String(100)),
- ... Column('password', String(15))
+ ... Column('name', String),
+ ... Column('fullname', String),
+ ... Column('password', String)
... )
-All about how to define `Table` objects, as well as how to create them from an existing database automatically, is described in [metadata](rel:metadata).
+All about how to define `Table` objects, as well as how to load their definition from an existing database (known as **reflection**), is described in [metadata](rel:metadata).
-Next, to tell the `MetaData` we'd actually like to create our `users_table` for real inside the SQLite database, we use `create_all()`, passing it the `engine` instance which points to our database. This will check for the presence of a table first before creating, so it's safe to call multiple times:
+Next, we can issue CREATE TABLE statements derived from our table metadata, by calling `create_all()` and passing it the `engine` instance which points to our database. This will check for the presence of a table first before creating, so it's safe to call multiple times:
{python}
{sql}>>> metadata.create_all(engine) # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
{}
CREATE TABLE users (
id INTEGER NOT NULL,
- name VARCHAR(40),
- fullname VARCHAR(100),
- password VARCHAR(15),
+ name VARCHAR,
+ fullname VARCHAR,
+ password VARCHAR,
PRIMARY KEY (id)
)
{}
COMMIT
-So now our database is created, our initial schema is present, and our SQLAlchemy application knows all about the tables and columns in the database; this information is to be re-used by the Object Relational Mapper, as we'll see now.
-
+Users familiar with the syntax of CREATE TABLE may notice that the VARCHAR columns were generated without a length; on SQLite, this is a valid datatype, but on most databases it's not allowed. So if running this tutorial on a database such as Postgres or MySQL, and you wish to use SQLAlchemy to generate the tables, a "length" may be provided to the `String` type as below:
+
+ {python}
+ Column('name', String(50))
+
+The length field on `String`, as well as similar precision/scale fields available on `Integer`, `Numeric`, etc. are not referenced by SQLAlchemy other than when creating tables.
+
## Define a Python Class to be Mapped {@name=mapping}
-So lets create a rudimentary `User` object to be mapped in the database. This object will for starters have three attributes, `name`, `fullname` and `password`. It only need subclass Python's built-in `object` class (i.e. it's a new style class). We will give it a constructor so that it may conveniently be instantiated with its attributes at once, as well as a `__repr__` method so that we can get a nice string representation of it:
+While the `Table` object defines information about our database, it does not say anything about the definition or behavior of the business objects used by our application; SQLAlchemy views this as a separate concern. To correspond to our `users` table, let's create a rudimentary `User` class. It need only subclass Python's built-in `object` class (i.e. it's a new style class):
{python}
    >>> class User(object):
    ...     def __init__(self, name, fullname, password):
    ...         self.name = name
    ...         self.fullname = fullname
    ...         self.password = password
    ...
    ...     def __repr__(self):
    ...         return "<User('%s','%s', '%s')>" % (self.name, self.fullname, self.password)
+The class has an `__init__()` and a `__repr__()` method for convenience. These methods are both entirely optional, and can be of any form. SQLAlchemy never calls `__init__()` directly.
+
## Setting up the Mapping
With our `users_table` and `User` class, we now want to map the two together. That's where the SQLAlchemy ORM package comes in. We'll use the `mapper` function to create a **mapping** between `users_table` and `User`:
{python}
>>> from sqlalchemy.orm import mapper
>>> mapper(User, users_table) # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
- <sqlalchemy.orm.mapper.Mapper object at 0x...>
+ <Mapper at 0x...; User>
-The `mapper()` function creates a new `Mapper` object and stores it away for future reference. It also **instruments** the attributes on our `User` class, corresponding to the `users_table` table. The `id`, `name`, `fullname`, and `password` columns in our `users_table` are now instrumented upon our `User` class, meaning it will keep track of all changes to these attributes, and can save and load their values to/from the database. Lets create our first user, 'Ed Jones', and ensure that the object has all three of these attributes:
+The `mapper()` function creates a new `Mapper` object and stores it away for future reference, associated with our class. Let's now create and inspect a `User` object:
{python}
>>> ed_user = User('ed', 'Ed Jones', 'edspassword')
>>> str(ed_user.id)
'None'
-What was that last `id` attribute? That was placed there by the `Mapper`, to track the value of the `id` column in the `users_table`. Since our `User` doesn't exist in the database, its id is `None`. When we save the object, it will get populated automatically with its new id.
+The `id` attribute, while not defined by our `__init__()` method, exists due to the `id` column present within the `users_table` object. By default, the `mapper` creates class attributes for all columns present within the `Table`. These class attributes exist as Python descriptors, and define **instrumentation** for the mapped class. The functionality of this instrumentation is very rich and includes the ability to track modifications and automatically load new data from the database when needed.
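+
+For example, accessing `User.name` on the class itself reveals the instrumented descriptor (the exact repr shown is illustrative):
+
+ {python}
+ >>> User.name # doctest: +SKIP
+ <sqlalchemy.orm.attributes.InstrumentedAttribute object at 0x...>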
-## Too Verbose ? There are alternatives
+Since we have not yet told SQLAlchemy to persist `Ed Jones` within the database, its id is `None`. When we persist the object later, this attribute will be populated with a newly generated value.
-The full set of steps to map a class, which are to define a `Table`, define a class, and then define a `mapper()`, are fairly verbose and for simple cases may appear overly disjoint. Most popular object relational products use the so-called "active record" approach, where the table definition and its class mapping are all defined at once. With SQLAlchemy, there are two excellent alternatives to its usual configuration which provide this approach:
+## Creating Table, Class and Mapper All at Once Declaratively {@name=declarative}
- * [Elixir](http://elixir.ematia.de/) is a "sister" product to SQLAlchemy, which is a full "declarative" layer built on top of SQLAlchemy. It has existed almost as long as SA itself and defines a rich featureset on top of SA's normal configuration, adding many new capabilities such as plugins, automatic generation of table and column names based on configurations, and an intuitive system of defining relations.
- * [declarative](rel:plugins_declarative) is a so-called "micro-declarative" plugin included with SQLAlchemy 0.4.4 and above. In contrast to Elixir, it maintains the use of the same configurational constructs outlined in this tutorial, except it allows the `Column`, `relation()`, and other constructs to be defined "inline" with the mapped class itself, so that explicit calls to `Table` and `mapper()` are not needed in most cases.
+The preceding approach to configuration involving a `Table`, user-defined class, and `mapper()` call illustrates classical SQLAlchemy usage, which values the highest separation of concerns possible. A large number of applications don't require this degree of separation, and for those SQLAlchemy offers an alternate "shorthand" configurational style called **declarative**. For many applications, this is the only style of configuration needed. Our above example using this style is as follows:
-With either declarative layer it's a good idea to be familiar with SQLAlchemy's "base" configurational style in any case. But now that we have our configuration started, we're ready to look at how to build sessions and query the database; this process is the same regardless of configurational style.
+ {python}
+ >>> from sqlalchemy.ext.declarative import declarative_base
+
+ >>> Base = declarative_base()
+ >>> class User(Base):
+ ... __tablename__ = 'users'
+ ...
+ ... id = Column(Integer, primary_key=True)
+ ... name = Column(String)
+ ... fullname = Column(String)
+ ... password = Column(String)
+ ...
+ ... def __init__(self, name, fullname, password):
+ ... self.name = name
+ ... self.fullname = fullname
+ ... self.password = password
+ ...
+ ... def __repr__(self):
+ ... return "<User('%s','%s', '%s')>" % (self.name, self.fullname, self.password)
-## Creating a Session
+Above, the `declarative_base()` function defines a new class which we name `Base`, from which all of our ORM-enabled classes will derive. Note that we define `Column` objects with no "name" field, since it's inferred from the given attribute name.
+
+The underlying `Table` object created by our `declarative_base()` version of `User` is accessible via the `__table__` attribute:
-We're now ready to start talking to the database. The ORM's "handle" to the database is the `Session`. When we first set up the application, at the same level as our `create_engine()` statement, we define a second object called `Session` (or whatever you want to call it, `create_session`, etc.) which is configured by the `sessionmaker()` function. This function is configurational and need only be called once.
+ {python}
+ >>> users_table = User.__table__
+and the owning `MetaData` object is available as well:
+
+ {python}
+ >>> metadata = Base.metadata
+
+Yet another "declarative" method is available for SQLAlchemy as a third party library called [Elixir](http://elixir.ematia.de/). This is a full-featured configurational product which also includes many higher level mapping configurations built in. Like declarative, once classes and mappings are defined, ORM usage is the same as with a classical SQLAlchemy configuration.
+
+## Creating a Session
+
+We're now ready to start talking to the database. The ORM's "handle" to the database is the `Session`. When we first set up the application, at the same level as our `create_engine()` statement, we define a `Session` class which will serve as a factory for new `Session` objects:
+
{python}
>>> from sqlalchemy.orm import sessionmaker
- >>> Session = sessionmaker(bind=engine, autoflush=True, transactional=True)
+ >>> Session = sessionmaker(bind=engine)
In the case where your application does not yet have an `Engine` when you define your module-level objects, just set it up like this:
{python}
- >>> Session = sessionmaker(autoflush=True, transactional=True)
+ >>> Session = sessionmaker()
Later, when you create your engine with `create_engine()`, connect it to the `Session` using `configure()`:
{python}
>>> Session.configure(bind=engine) # once engine is available
-This `Session` class will create new `Session` objects which are bound to our database and have the transactional characteristics we've configured. Whenever you need to have a conversation with the database, you instantiate a `Session`:
+This custom-made `Session` class will create new `Session` objects which are bound to our database. Other transactional characteristics may be defined when calling `sessionmaker()` as well; these are described in a later chapter. Then, whenever you need to have a conversation with the database, you instantiate a `Session`:
{python}
>>> session = Session()
-The above `Session` is associated with our SQLite `engine`, but it hasn't opened any connections yet. When it's first used, it retrieves a connection from a pool of connections maintained by the `engine`, and holds onto it until we commit all changes and/or close the session object. Because we configured `transactional=True`, there's also a transaction in progress (one notable exception to this is MySQL, when you use its default table style of MyISAM). There's options available to modify this behavior but we'll go with this straightforward version to start.
+The above `Session` is associated with our SQLite `engine`, but it hasn't opened any connections yet. When it's first used, it retrieves a connection from a pool of connections maintained by the `engine`, and holds onto it until we commit all changes and/or close the session object.
-## Saving Objects
+## Adding new Objects
-So saving our `User` is as easy as issuing `save()`:
+To persist our `User` object, we `add()` it to our `Session`:
{python}
- >>> session.save(ed_user)
+ >>> ed_user = User('ed', 'Ed Jones', 'edspassword')
+ >>> session.add(ed_user)
-But you'll notice nothing has happened yet. Well, lets pretend something did, and try to query for our user. This is done using the `query()` method on `Session`. We create a new query representing the set of all `User` objects first. Then we narrow the results by "filtering" down to the user we want; that is, the user whose `name` attribute is `"ed"`. Finally we call `first()` which tells `Query`, "we'd like the first result in this list".
+At this point, the instance is **pending**; no SQL has yet been issued. The `Session` will issue the SQL to persist `Ed Jones` as soon as is needed, using a process known as a **flush**. If we query the database for `Ed Jones`, all pending information will first be flushed, and the query is issued afterwards.
+
+For example, below we create a new `Query` object which loads instances of `User`. We "filter by" the `name` attribute of `ed`, and indicate that we'd like only the first result in the full list of rows. A `User` instance is returned which is equivalent to that which we've added:
{python}
- {sql}>>> session.query(User).filter_by(name='ed').first() # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
+ {sql}>>> our_user = session.query(User).filter_by(name='ed').first() # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
BEGIN
INSERT INTO users (name, fullname, password) VALUES (?, ?, ?)
['ed', 'Ed Jones', 'edspassword']
WHERE users.name = ? ORDER BY users.oid
LIMIT 1 OFFSET 0
['ed']
- {stop}<User('ed','Ed Jones', 'edspassword')>
+ {stop}>>> our_user
+ <User('ed','Ed Jones', 'edspassword')>
+
+In fact, the `Session` has identified that the row returned is the **same** row as one already represented within its internal map of objects, so we actually got back the identical instance as that which we just added:
+
+ {python}
+ >>> ed_user is our_user
+ True
-And we get back our new user. If you view the generated SQL, you'll see that the `Session` issued an `INSERT` statement before querying. The `Session` stores whatever you put into it in memory, and at certain points it issues a **flush**, which issues SQL to the database to store all pending new objects and changes to existing objects. You can manually invoke the flush operation using `flush()`; however when the `Session` is configured to `autoflush`, it's usually not needed.
+The ORM concept at work here is known as an **identity map** and ensures that all operations upon a particular row within a `Session` operate upon the same set of data. Once an object with a particular primary key is present in the `Session`, all SQL queries on that `Session` will always return the same Python object for that particular primary key; it also will raise an error if an attempt is made to place a second, already-persisted object with the same primary key within the session.
-OK, let's do some more operations. We'll create and save three more users:
+We can add more `User` objects at once using `add_all()`:
{python}
- >>> session.save(User('wendy', 'Wendy Williams', 'foobar'))
- >>> session.save(User('mary', 'Mary Contrary', 'xxg527'))
- >>> session.save(User('fred', 'Fred Flinstone', 'blah'))
+ >>> session.add_all([
+ ... User('wendy', 'Wendy Williams', 'foobar'),
+ ... User('mary', 'Mary Contrary', 'xxg527'),
+ ... User('fred', 'Fred Flinstone', 'blah')])
Also, Ed has already decided his password isn't too secure, so let's change it:
{python}
>>> ed_user.password = 'f8s7ccs'
+
+The `Session` is paying attention. It knows, for example, that `Ed Jones` has been modified:
+
+ {python}
+ >>> session.dirty
+ IdentitySet([<User('ed','Ed Jones', 'f8s7ccs')>])
-Then we'll permanently store everything thats been changed and added to the database. We do this via `commit()`:
+and that three new `User` objects are pending:
+
+ {python}
+ >>> session.new # doctest: +NORMALIZE_WHITESPACE
+ IdentitySet([<User('wendy','Wendy Williams', 'foobar')>,
+ <User('mary','Mary Contrary', 'xxg527')>,
+ <User('fred','Fred Flinstone', 'blah')>])
+
+We tell the `Session` that we'd like to issue all remaining changes to the database and commit the transaction, which has been in progress throughout. We do this via `commit()`:
{python}
{sql}>>> session.commit()
If we look at Ed's `id` attribute, which earlier was `None`, it now has a value:
{python}
- >>> ed_user.id
- 1
-
-After each `INSERT` operation, the `Session` assigns all newly generated ids and column defaults to the mapped object instance. For column defaults which are database-generated and are not part of the table's primary key, they'll be loaded when you first reference the attribute on the instance.
-
-One crucial thing to note about the `Session` is that each object instance is cached within the Session, based on its primary key identifier. The reason for this cache is not as much for performance as it is for maintaining an **identity map** of instances. This map guarantees that whenever you work with a particular `User` object in a session, **you always get the same instance back**. As below, reloading Ed gives us the same instance back:
-
- {python}
- {sql}>>> ed_user is session.query(User).filter_by(name='ed').one() # doctest: +NORMALIZE_WHITESPACE
+ {sql}>>> ed_user.id # doctest: +NORMALIZE_WHITESPACE
BEGIN
SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
FROM users
- WHERE users.name = ? ORDER BY users.oid
- LIMIT 2 OFFSET 0
- ['ed']
- {stop}True
+ WHERE users.id = ?
+ [1]
+ {stop}1
-The `get()` method, which queries based on primary key, will not issue any SQL to the database if the given key is already present:
+After the `Session` inserts new rows in the database, all newly generated identifiers and database-generated defaults become available on the instance, either immediately or via load-on-first-access. In this case, the entire row was re-loaded on access because a new transaction was begun after we issued `commit()`. SQLAlchemy by default refreshes data from a previous transaction the first time it's accessed within a new transaction, so that the most recent state is available. The level of reloading is configurable as is described in the chapter on Sessions.
- {python}
- >>> ed_user is session.query(User).get(ed_user.id)
- True
-
## Querying
-A whirlwind tour through querying.
-
-A `Query` is created from the `Session`, relative to a particular class we wish to load.
+A `Query` is created using the `query()` function on `Session`. This function takes a variable number of arguments, which can be any combination of classes and class-instrumented descriptors. Below, we indicate a `Query` which loads `User` instances. When evaluated in an iterative context, the list of `User` objects present is returned:
{python}
- >>> query = session.query(User)
+ {sql}>>> for instance in session.query(User): # doctest: +NORMALIZE_WHITESPACE
+ ... print instance.name, instance.fullname
+ SELECT users.id AS users_id, users.name AS users_name,
+ users.fullname AS users_fullname, users.password AS users_password
+ FROM users ORDER BY users.oid
+ []
+ {stop}ed Ed Jones
+ wendy Wendy Williams
+ mary Mary Contrary
+ fred Fred Flinstone
-Once we have a query, we can start loading objects. The Query object, when first created, represents all the instances of its main class. You can iterate through it directly:
+The `Query` also accepts ORM-instrumented descriptors as arguments. Any time multiple class entities or column-based entities are expressed as arguments to the `query()` function, the return result is expressed as tuples:
{python}
- {sql}>>> for user in session.query(User):
- ... print user.name
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users ORDER BY users.oid
+ {sql}>>> for name, fullname in session.query(User.name, User.fullname): # doctest: +NORMALIZE_WHITESPACE
+ ... print name, fullname
+ SELECT users.name AS users_name, users.fullname AS users_fullname
+ FROM users
[]
- {stop}ed
- wendy
- mary
- fred
+ {stop}ed Ed Jones
+ wendy Wendy Williams
+ mary Mary Contrary
+ fred Fred Flinstone
-...and the SQL will be issued at the point where the query is evaluated as a list. If you apply array slices before iterating, LIMIT and OFFSET are applied to the query:
+Basic operations with `Query` include issuing LIMIT and OFFSET, most conveniently using Python array slices and typically in conjunction with ORDER BY:
{python}
- {sql}>>> for u in session.query(User)[1:3]: #doctest: +NORMALIZE_WHITESPACE
+ {sql}>>> for u in session.query(User).order_by(User.id)[1:3]: #doctest: +NORMALIZE_WHITESPACE
... print u
SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users ORDER BY users.oid
+ FROM users ORDER BY users.id
LIMIT 2 OFFSET 1
[]
{stop}<User('wendy','Wendy Williams', 'foobar')>
<User('mary','Mary Contrary', 'xxg527')>
-Narrowing the results down is accomplished either with `filter_by()`, which uses keyword arguments:
+and filtering results, which is accomplished either with `filter_by()`, which uses keyword arguments:
{python}
- {sql}>>> for user in session.query(User).filter_by(name='ed', fullname='Ed Jones'):
- ... print user
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.fullname = ? AND users.name = ? ORDER BY users.oid
- ['Ed Jones', 'ed']
- {stop}<User('ed','Ed Jones', 'f8s7ccs')>
+ {sql}>>> for name, in session.query(User.name).filter_by(fullname='Ed Jones'): # doctest: +NORMALIZE_WHITESPACE
+ ... print name
+ SELECT users.name AS users_name FROM users
+ WHERE users.fullname = ?
+ ['Ed Jones']
+ {stop}ed
-...or `filter()`, which uses SQL expression language constructs. These allow you to use regular Python operators with the class-level attributes on your mapped class:
+...or `filter()`, which uses more flexible SQL expression language constructs. These allow you to use regular Python operators with the class-level attributes on your mapped class:
{python}
- {sql}>>> for user in session.query(User).filter(User.name=='ed'):
- ... print user
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.name = ? ORDER BY users.oid
- ['ed']
- {stop}<User('ed','Ed Jones', 'f8s7ccs')>
+ {sql}>>> for name, in session.query(User.name).filter(User.fullname=='Ed Jones'): # doctest: +NORMALIZE_WHITESPACE
+ ... print name
+ SELECT users.name AS users_name FROM users
+ WHERE users.fullname = ?
+ ['Ed Jones']
+ {stop}ed
-You can also use the `Column` constructs attached to the `users_table` object to construct SQL expressions:
+The `Query` object is fully *generative*, meaning that most method calls return a new `Query` object upon which further criteria may be added. For example, to query for users named "ed" with a full name of "Ed Jones", you can call `filter()` twice, which joins criteria using `AND`:
{python}
- {sql}>>> for user in session.query(User).filter(users_table.c.name=='ed'):
+ {sql}>>> for user in session.query(User).filter(User.name=='ed').filter(User.fullname=='Ed Jones'): # doctest: +NORMALIZE_WHITESPACE
... print user
SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
FROM users
- WHERE users.name = ? ORDER BY users.oid
- ['ed']
+ WHERE users.name = ? AND users.fullname = ? ORDER BY users.oid
+ ['ed', 'Ed Jones']
{stop}<User('ed','Ed Jones', 'f8s7ccs')>
-Most common SQL operators are available, such as `LIKE`:
- {python}
- {sql}>>> session.query(User).filter(User.name.like('%ed'))[1] # doctest: +NORMALIZE_WHITESPACE
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.name LIKE ? ORDER BY users.oid
- LIMIT 1 OFFSET 1
- ['%ed']
- {stop}<User('fred','Fred Flinstone', 'blah')>
+### Common Filter Operators {@name=operators}
-Note above our array index of `1` placed the appropriate LIMIT/OFFSET and returned a scalar result immediately.
+Here's a rundown of some of the most common operators used in `filter()`:
-The `all()`, `one()`, and `first()` methods immediately issue SQL without using an iterative context or array index. `all()` returns a list:
+ * equals
+
+ {python}
+ query.filter(User.name == 'ed')
+
+ * not equals
+
+ {python}
+ query.filter(User.name != 'ed')
+
+ * LIKE
+
+ {python}
+ query.filter(User.name.like('%ed%'))
+
+ * IN
+
+ {python}
+ query.filter(User.name.in_(['ed', 'wendy', 'jack']))
+
+ * IS NULL
+
+ {python}
+ query.filter(User.name == None)
+
+ * AND
+
+ {python}
+ from sqlalchemy import and_
+ query.filter(and_(User.name == 'ed', User.fullname == 'Ed Jones'))
+
+ # or call filter()/filter_by() multiple times
+ query.filter(User.name == 'ed').filter(User.fullname == 'Ed Jones')
+
+ * OR
+
+ {python}
+ from sqlalchemy import or_
+ query.filter(or_(User.name == 'ed', User.name == 'wendy'))
+
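+ * NOT, using the `~` operator (a brief sketch; `~` negates any of the criteria above):
+
+ {python}
+ query.filter(~User.name.in_(['ed', 'wendy', 'jack']))
+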
+### Returning Lists and Scalars {@name=scalars}
+
+The `all()`, `one()`, and `first()` methods of `Query` immediately issue SQL and return a non-iterator value. `all()` returns a list:
{python}
>>> query = session.query(User).filter(User.name.like('%ed'))
-
{sql}>>> query.all()
SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
FROM users
WHERE users.name LIKE ? ORDER BY users.oid
['%ed']
{stop}[<User('ed','Ed Jones', 'f8s7ccs')>, <User('fred','Fred Flinstone', 'blah')>]
-and `one()`, applies a limit of *two*, and if not exactly one row returned (no more, no less), raises an error:
+`one()` applies a limit of *two*, and raises an error if not exactly one row is returned:
{python}
{sql}>>> try:
['%ed']
{stop}Multiple rows returned for one()
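+
+`first()` applies a limit of *one* and returns the first result as a scalar, or `None` if no rows are present; a brief sketch:
+
+ {python}
+ # LIMIT 1 is applied; the first User, or None, is returned
+ user = session.query(User).filter(User.name.like('%ed')).first()
+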
-All `Query` methods that don't return a result instead return a new `Query` object, with modifications applied. Therefore you can call many query methods successively to build up the criterion you want:
-
- {python}
- {sql}>>> session.query(User).filter(User.id<2).filter_by(name='ed').\
- ... filter(User.fullname=='Ed Jones').all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.id < ? AND users.name = ? AND users.fullname = ? ORDER BY users.oid
- [2, 'ed', 'Ed Jones']
- {stop}[<User('ed','Ed Jones', 'f8s7ccs')>]
-
-If you need to use other conjunctions besides `AND`, all SQL conjunctions are available explicitly within expressions, such as `and_()` and `or_()`, when using `filter()`:
+### Using Literal SQL {@name=literal}
- {python}
- >>> from sqlalchemy import and_, or_
-
- {sql}>>> session.query(User).filter(
- ... and_(User.id<224, or_(User.name=='ed', User.name=='wendy'))
- ... ).all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.id < ? AND (users.name = ? OR users.name = ?) ORDER BY users.oid
- [224, 'ed', 'wendy']
- {stop}[<User('ed','Ed Jones', 'f8s7ccs')>, <User('wendy','Wendy Williams', 'foobar')>]
-
-You also have full ability to use literal strings to construct SQL. For a single criterion, use a string with `filter()`:
+Literal strings can be used flexibly with `Query`. Most methods accept strings in addition to SQLAlchemy clause constructs. For example, `filter()`:
{python}
{sql}>>> for user in session.query(User).filter("id<224").all():
[224, 'fred']
{stop}<User('fred','Fred Flinstone', 'blah')>
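+
+Bind parameters can be used with string-based criterion as well, using colons; values are supplied via the `params()` method of `Query`. A brief sketch:
+
+ {python}
+ session.query(User).filter("id<:value and name=:name").\
+     params(value=224, name='fred').all()
+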
-Note that when we use constructed SQL expressions, bind parameters are generated for us automatically; we don't need to worry about them.
-
To use an entirely string-based statement, use `from_statement()`; just ensure that the columns clause of the statement contains the column names normally used by the mapper (below illustrated using an asterisk):
{python}
['ed']
{stop}[<User('ed','Ed Jones', 'f8s7ccs')>]
-`from_statement()` can also accomodate full `select()` constructs. These are described in the [sql](rel:sql):
+## Building a Relation {@name=relation}
- {python}
- >>> from sqlalchemy import select, func
-
- {sql}>>> session.query(User).from_statement(
- ... select(
- ... [users_table],
- ... select([func.max(users_table.c.name)]).label('maxuser')==users_table.c.name)
- ... ).all() # doctest: +NORMALIZE_WHITESPACE
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE (SELECT max(users.name) AS max_1
- FROM users) = users.name
- []
- {stop}[<User('wendy','Wendy Williams', 'foobar')>]
-
-There's also a way to combine scalar results with objects, using `add_column()`. This is often used for functions and aggregates. When `add_column()` (or its cousin `add_entity()`, described later) is used, tuples are returned:
+Now let's consider a second table. Users in our system can also store any number of email addresses associated with their username. This implies a basic one-to-many association from the `users_table` to a new table which stores email addresses, which we will call `addresses`. Using declarative, we define this table along with its mapped class, `Address`:
{python}
- {sql}>>> for r in session.query(User).\
- ... add_column(select([func.max(users_table.c.name)]).label('maxuser')):
- ... print r # doctest: +NORMALIZE_WHITESPACE
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password, (SELECT max(users.name) AS max_1
- FROM users) AS maxuser
- FROM users ORDER BY users.oid
- []
- {stop}(<User('ed','Ed Jones', 'f8s7ccs')>, u'wendy')
- (<User('wendy','Wendy Williams', 'foobar')>, u'wendy')
- (<User('mary','Mary Contrary', 'xxg527')>, u'wendy')
- (<User('fred','Fred Flinstone', 'blah')>, u'wendy')
+ >>> from sqlalchemy import ForeignKey
+ >>> from sqlalchemy.orm import relation
+ >>> class Address(Base):
+ ... __tablename__ = 'addresses'
+ ... id = Column(Integer, primary_key=True)
+ ... email_address = Column(String, nullable=False)
+ ... user_id = Column(Integer, ForeignKey('users.id'))
+ ...
+ ... user = relation(User, backref='addresses')
+ ...
+ ... def __init__(self, email_address):
+ ... self.email_address = email_address
+ ...
+ ... def __repr__(self):
+ ... return "<Address('%s')>" % self.email_address
-## Building a One-to-Many Relation {@name=onetomany}
+The above class introduces a **foreign key** constraint which references the `users` table. This defines for SQLAlchemy the relationship between the two tables at the database level. The relationship between the `User` and `Address` classes is defined separately using the `relation()` function, which defines an attribute `user` to be placed on the `Address` class, as well as an `addresses` collection to be placed on the `User` class. Such a relation is known as a **bidirectional** relationship. Because of the placement of the foreign key, from `Address` to `User` it is **many to one**, and from `User` to `Address` it is **one to many**. SQLAlchemy is automatically aware of many-to-one/one-to-many based on foreign keys.
-We've spent a lot of time dealing with just one class, and one table. Let's now look at how SQLAlchemy deals with two tables, which have a relationship to each other. Let's say that the users in our system also can store any number of email addresses associated with their username. This implies a basic one to many association from the `users_table` to a new table which stores email addresses, which we will call `addresses`. We will also create a relationship between this new table to the users table, using a `ForeignKey`:
+The `relation()` function is extremely flexible, and could just as easily have been defined on the `User` class:
{python}
- >>> from sqlalchemy import ForeignKey
-
- >>> addresses_table = Table('addresses', metadata,
- ... Column('id', Integer, primary_key=True),
- ... Column('email_address', String(100), nullable=False),
- ... Column('user_id', Integer, ForeignKey('users.id')))
-
-Another call to `create_all()` will skip over our `users` table and build just the new `addresses` table:
+ class User(Base):
+     # ....
+     addresses = relation("Address", backref="user")
+
+Above we used the string name `"Address"`, which allows the `relation()` to be configured even when the `Address` class is not yet defined. We are also free to not define a backref, and to define the `relation()` on only one class and not the other. It is also possible to define two separate `relation()`s, one for each direction, which is generally safe for many-to-one and one-to-many relations, but not for many-to-many relations; a sketch of this form follows.
+
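+As a minimal sketch of that two-relation form (assuming the same declarative classes, with columns elided), no `backref` is used and each class defines its own `relation()`:
+
+ {python}
+ class User(Base):
+     # ....
+     addresses = relation("Address")
+
+ class Address(Base):
+     # ....
+     user = relation("User")
+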
+We'll need to create the `addresses` table in the database, so we issue another CREATE from our metadata; it will skip over tables which have already been created:
{python}
{sql}>>> metadata.create_all(engine) # doctest: +NORMALIZE_WHITESPACE
{}
CREATE TABLE addresses (
id INTEGER NOT NULL,
- email_address VARCHAR(100) NOT NULL,
+ email_address VARCHAR NOT NULL,
user_id INTEGER,
PRIMARY KEY (id),
FOREIGN KEY(user_id) REFERENCES users (id)
{}
COMMIT
-For our ORM setup, we're going to start all over again. We will first close out our `Session` and clear all `Mapper` objects:
-
- {python}
- >>> from sqlalchemy.orm import clear_mappers
- >>> session.close()
- >>> clear_mappers()
-
-Our `User` class, still around, reverts to being just a plain old class. Lets create an `Address` class to represent a user's email address:
-
- {python}
- >>> class Address(object):
- ... def __init__(self, email_address):
- ... self.email_address = email_address
- ...
- ... def __repr__(self):
- ... return "<Address('%s')>" % self.email_address
-
-Now comes the fun part. We define a mapper for each class, and associate them using a function called `relation()`. We can define each mapper in any order we want:
-
- {python}
- >>> from sqlalchemy.orm import relation
-
- >>> mapper(User, users_table, properties={ # doctest: +ELLIPSIS
- ... 'addresses':relation(Address, backref='user')
- ... })
- <sqlalchemy.orm.mapper.Mapper object at 0x...>
-
- >>> mapper(Address, addresses_table) # doctest: +ELLIPSIS
- <sqlalchemy.orm.mapper.Mapper object at 0x...>
-
-Above, the new thing we see is that `User` has defined a relation named `addresses`, which will reference a list of `Address` objects. How does it know it's a list ? SQLAlchemy figures it out for you, based on the foreign key relationship between `users_table` and `addresses_table`.
+## Working with Related Objects {@name=related_objects}
-## Working with Related Objects and Backreferences {@name=relation_backref}
-
-Now when we create a `User`, it automatically has this collection present:
+Now when we create a `User`, a blank `addresses` collection will be present. By default, the collection is a Python list. Other collection types, such as sets and dictionaries, are available as well:
{python}
>>> jack = User('jack', 'Jack Bean', 'gjffdd')
>>> jack.addresses
[]
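+
+As a minimal sketch of swapping in a different collection type, the `collection_class` argument can be given to `relation()`, or to the `backref()` function (introduced later in this tutorial) for the reverse side; here, hypothetically storing `User.addresses` as a set:
+
+ {python}
+ from sqlalchemy.orm import relation, backref
+
+ # on the Address class: many-to-one to User, with the reverse
+ # 'addresses' collection stored as a set instead of a list
+ user = relation(User, backref=backref('addresses', collection_class=set))
+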
-We are free to add `Address` objects, and the `session` will take care of everything for us.
+We are free to add `Address` objects to our `User` object. In this case we just assign a full list directly:
{python}
- >>> jack.addresses.append(Address(email_address='jack@google.com'))
- >>> jack.addresses.append(Address(email_address='j25@yahoo.com'))
-
-Before we save into the `Session`, lets examine one other thing that's happened here. The `addresses` collection is present on our `User` because we added a `relation()` with that name. But also within the `relation()` function is the keyword `backref`. This keyword indicates that we wish to make a **bi-directional relationship**. What this basically means is that not only did we generate a one-to-many relationship called `addresses` on the `User` class, we also generated a **many-to-one** relationship on the `Address` class. This relationship is self-updating, without any data being flushed to the database, as we can see on one of Jack's addresses:
+ >>> jack.addresses = [Address(email_address='jack@google.com'), Address(email_address='j25@yahoo.com')]
+
+When using a bidirectional relationship, elements added in one direction automatically become visible in the other direction. This is the basic behavior of the **backref** keyword, which maintains the relationship purely in memory, without using any SQL:
{python}
>>> jack.addresses[1]
<Address('j25@yahoo.com')>
>>> jack.addresses[1].user
<User('jack','Jack Bean', 'gjffdd')>
-
-Let's save into the session, then close out the session and create a new one...so that we can see how `Jack` and his email addresses come back to us:
+
+Let's add and commit `Jack Bean` to the database. `jack`, as well as the two `Address` members in his `addresses` collection, are added to the session at once, using a process known as **cascading**:
{python}
- >>> session.save(jack)
+ >>> session.add(jack)
{sql}>>> session.commit()
- BEGIN
INSERT INTO users (name, fullname, password) VALUES (?, ?, ?)
['jack', 'Jack Bean', 'gjffdd']
INSERT INTO addresses (email_address, user_id) VALUES (?, ?)
['jack@google.com', 5]
INSERT INTO addresses (email_address, user_id) VALUES (?, ?)
['j25@yahoo.com', 5]
COMMIT
- >>> session = Session()
-
Querying for Jack, we get just Jack back. No SQL is yet issued for Jack's addresses:
{python}
[5]
{stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
-When we accessed the `addresses` collection, SQL was suddenly issued. This is an example of a **lazy loading relation**.
-
-If you want to reduce the number of queries (dramatically, in many cases), we can apply an **eager load** to the query operation. We clear out the session to ensure that a full reload occurs:
-
- {python}
- >>> session.clear()
+When we accessed the `addresses` collection, SQL was suddenly issued. This is an example of a **lazy loading relation**. The `addresses` collection is now loaded and behaves just like an ordinary list.
-Then apply an **option** to the query, indicating that we'd like `addresses` to load "eagerly". SQLAlchemy then constructs a join between the `users` and `addresses` tables:
+If you want to reduce the number of queries (dramatically, in many cases), we can apply an **eager load** to the query operation. With the same query, we may apply an **option** to the query, indicating that we'd like `addresses` to load "eagerly". SQLAlchemy then constructs an outer join between the `users` and `addresses` tables, and loads them at once, populating the `addresses` collection on each `User` object if it's not already populated:
{python}
>>> from sqlalchemy.orm import eagerload
>>> jack.addresses
[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
-
-If you think that query is elaborate, it is ! But SQLAlchemy is just getting started. Note that when using eager loading, *nothing* changes as far as the ultimate results returned. The "loading strategy", as it's called, is designed to be completely transparent in all cases, and is for optimization purposes only. Any query criterion you use to load objects, including ordering, limiting, other joins, etc., should return identical results regardless of the combination of lazily- and eagerly- loaded relationships present.
-An eagerload targeting across multiple relations can use dot separated names:
+SQLAlchemy has the ability to control exactly which attributes and how many levels deep should be joined together in a single SQL query. More information on this feature is available in [advdatamapping_relation](rel:advdatamapping_relation).
- {python}
- query.options(eagerload('orders'), eagerload('orders.items'), eagerload('orders.items.keywords'))
-
-To roll up the above three individual `eagerload()` calls into one, use `eagerload_all()`:
-
- {python}
- query.options(eagerload_all('orders.items.keywords'))
-
## Querying with Joins {@name=joins}
-Which brings us to the next big topic. What if we want to create joins that *do* change the results ? For that, another `Query` tornado is coming....
-
-One way to join two tables together is just to compose a SQL expression. Below we make one up using the `id` and `user_id` attributes on our mapped classes:
+While the eager load created a JOIN specifically to populate a collection, we can also work explicitly with joins in many ways. For example, to construct a simple inner join between `User` and `Address`, we can just `filter()` their related columns together. Below we load the `User` and `Address` entities at once using this method:
{python}
- {sql}>>> session.query(User).filter(User.id==Address.user_id).\
- ... filter(Address.email_address=='jack@google.com').all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
+ {sql}>>> for u, a in session.query(User, Address).filter(User.id==Address.user_id).\
+ ... filter(Address.email_address=='jack@google.com').all(): # doctest: +NORMALIZE_WHITESPACE
+ ... print u, a
+ SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname,
+ users.password AS users_password, addresses.id AS addresses_id,
+ addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
FROM users, addresses
WHERE users.id = addresses.user_id AND addresses.email_address = ? ORDER BY users.oid
['jack@google.com']
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+ {stop}<User('jack','Jack Bean', 'gjffdd')> <Address('jack@google.com')>
-Or we can make a real JOIN construct; below we use the `join()` function available on `Table` to create a `Join` object, then tell the `Query` to use it as our FROM clause:
+Or we can make a real JOIN construct; one way to do so is to use the ORM `join()` function, and tell `Query` to "select from" this join:
{python}
- {sql}>>> session.query(User).select_from(users_table.join(addresses_table)).\
+ >>> from sqlalchemy.orm import join
+ {sql}>>> session.query(User).select_from(join(User, Address)).\
... filter(Address.email_address=='jack@google.com').all()
SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
FROM users JOIN addresses ON users.id = addresses.user_id
['jack@google.com']
{stop}[<User('jack','Jack Bean', 'gjffdd')>]
-Note that the `join()` construct has no problem figuring out the correct join condition between `users_table` and `addresses_table`..the `ForeignKey` we constructed says it all.
+`join()` knows how to join between `User` and `Address` because there's only one foreign key between them. If there were no foreign keys, or several, `join()` would require a third argument indicating the ON clause of the join, in one of the following forms:
-The easiest way to join is automatically, using the `join()` method on `Query`. Just give this method the path from A to B, using the name of a mapped relationship directly:
+ {python}
+ join(User, Address, User.id==Address.user_id) # explicit condition
+ join(User, Address, User.addresses) # specify relation from left to right
+ join(User, Address, 'addresses') # same, using a string
+
+The functionality of `join()` is also available generatively from `Query` itself using `Query.join`. This is most easily used with just the "ON" clause portion of the join, such as:
{python}
- {sql}>>> session.query(User).join('addresses').\
+ {sql}>>> session.query(User).join(User.addresses).\
... filter(Address.email_address=='jack@google.com').all()
SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
FROM users JOIN addresses ON users.id = addresses.user_id
['jack@google.com']
{stop}[<User('jack','Jack Bean', 'gjffdd')>]
-By "A to B", we mean a single relation name or a path of relations. In our case we only have `User->addresses->Address` configured, but if we had a setup like `A->bars->B->bats->C->widgets->D`, a join along all four entities would look like:
+To explicitly specify the target of the join, use tuples to form an argument list similar to the standalone join. This becomes more important when using aliases and similar constructs:
{python}
- session.query(Foo).join(['bars', 'bats', 'widgets']).filter(...)
+ session.query(User).join((Address, User.addresses))
-Each time `join()` is called on `Query`, the **joinpoint** of the query is moved to be that of the endpoint of the join. As above, when we joined from `users_table` to `addresses_table`, all subsequent criterion used by `filter_by()` are against the `addresses` table. When you `join()` again, the joinpoint starts back from the root. We can also backtrack to the beginning explicitly using `reset_joinpoint()`. This instruction will place the joinpoint back at the root `users` table, where subsequent `filter_by()` criterion are again against `users`:
+Multiple joins can be created by passing a list of arguments:
{python}
- {sql}>>> session.query(User).join('addresses').\
- ... filter_by(email_address='jack@google.com').\
- ... reset_joinpoint().filter_by(name='jack').all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users JOIN addresses ON users.id = addresses.user_id
- WHERE addresses.email_address = ? AND users.name = ? ORDER BY users.oid
- ['jack@google.com', 'jack']
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
-
-In all cases, we can get the `User` and the matching `Address` objects back at the same time, by telling the session we want both. This returns the results as a list of tuples:
-
- {python}
- {sql}>>> session.query(User).add_entity(Address).join('addresses').\
- ... filter(Address.email_address=='jack@google.com').all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password, addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM users JOIN addresses ON users.id = addresses.user_id
- WHERE addresses.email_address = ? ORDER BY users.oid
- ['jack@google.com']
- {stop}[(<User('jack','Jack Bean', 'gjffdd')>, <Address('jack@google.com')>)]
+ session.query(Foo).join(Foo.bars, Bar.bats, (Bat, 'widgets'))
+
+The above would produce SQL something like `foo JOIN bars ON <onclause> JOIN bats ON <onclause> JOIN widgets ON <onclause>`.
+
+### Using Aliases {@name=aliases}
-Another common scenario is the need to join on the same table more than once. For example, if we want to find a `User` who has two distinct email addresses, both `jack@google.com` as well as `j25@yahoo.com`, we need to join to the `Addresses` table twice. SQLAlchemy does provide `Alias` objects which can accomplish this; but far easier is just to tell `join()` to alias for you:
+When querying across multiple tables, if the same table needs to be referenced more than once, SQL typically requires that the table be *aliased* with another name, so that it can be distinguished from other occurrences of that table. The `Query` supports this most explicitly using the `aliased` construct. Below we join to the `Address` entity twice, to locate a user who has two distinct email addresses at the same time:
{python}
- {sql}>>> session.query(User).\
- ... join('addresses', aliased=True).filter(Address.email_address=='jack@google.com').\
- ... join('addresses', aliased=True).filter(Address.email_address=='j25@yahoo.com').all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id JOIN addresses AS addresses_2 ON users.id = addresses_2.user_id
- WHERE addresses_1.email_address = ? AND addresses_2.email_address = ? ORDER BY users.oid
+ >>> from sqlalchemy.orm import aliased
+ >>> adalias1 = aliased(Address)
+ >>> adalias2 = aliased(Address)
+ {sql}>>> for username, email1, email2 in session.query(User.name, adalias1.email_address, adalias2.email_address).\
+ ... join((adalias1, User.addresses), (adalias2, User.addresses)).\
+ ... filter(adalias1.email_address=='jack@google.com').\
+ ... filter(adalias2.email_address=='j25@yahoo.com'):
+ ... print username, email1, email2 # doctest: +NORMALIZE_WHITESPACE
+ SELECT users.name AS users_name, addresses_1.email_address AS addresses_1_email_address,
+ addresses_2.email_address AS addresses_2_email_address
+ FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id
+ JOIN addresses AS addresses_2 ON users.id = addresses_2.user_id
+ WHERE addresses_1.email_address = ? AND addresses_2.email_address = ?
['jack@google.com', 'j25@yahoo.com']
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+ {stop}jack jack@google.com j25@yahoo.com
-The key thing which occurred above is that our SQL criterion were **aliased** as appropriate corresponding to the alias generated in the most recent `join()` call.
+### Using Subqueries {@name=subqueries}
-The next section describes some "higher level" operators, including `any()` and `has()`, which make patterns like joining to multiple aliases unnecessary in most cases.
+The `Query` is suitable for generating statements which can be used as subqueries. Suppose we wanted to load `User` objects along with a count of how many `Address` records each user has. The best way to generate SQL like this is to get the count of addresses grouped by user ids, and JOIN to the parent. In this case we use a LEFT OUTER JOIN so that we get rows back for those users who don't have any addresses, e.g.:
-### Relation Operators
+ {code}
+ SELECT users.*, adr_count.address_count FROM users LEFT OUTER JOIN
+ (SELECT user_id, count(*) AS address_count FROM addresses GROUP BY user_id) AS adr_count
+ ON users.id=adr_count.user_id
-A summary of all operators usable on relations:
+Using the `Query`, we build a statement like this from the inside out. The `statement` accessor returns a SQL expression representing the statement generated by a particular `Query`; this is an instance of a `select()` construct, which is described in [sql](rel:sql):
+
+ {python}
+ >>> from sqlalchemy.sql import func
+ >>> stmt = session.query(Address.user_id, func.count('*').label('address_count')).group_by(Address.user_id).statement.alias()
+
+The `func` object generates SQL function expressions, and the `alias()` method on `Select` (the return value of `query.statement`) creates a SQL alias; in this case an anonymous one, which will have a generated name.
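+
+As a quick standalone illustration of `func` (separate from the query above), printing a function expression shows the SQL it will render:
+
+ {python}
+ from sqlalchemy.sql import func
+ print func.max(users_table.c.name)   # renders: max(users.name)
+ print func.count('*')                # renders: count(:param_1), with '*' as the bound value
+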
-* Filter on explicit column criterion, combined with a join. Column criterion can make usage of all supported SQL operators and expression constructs:
+Once we have our statement, it behaves like a `Table` construct, which we created for `users` at the top of this tutorial. The columns on the statement are accessible through an attribute called `c`:
- {python}
- {sql}>>> session.query(User).join('addresses').\
- ... filter(Address.email_address=='jack@google.com').all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users JOIN addresses ON users.id = addresses.user_id
- WHERE addresses.email_address = ? ORDER BY users.oid
- ['jack@google.com']
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+ {python}
+ {sql}>>> for u, count in session.query(User, stmt.c.address_count).outerjoin((stmt, User.id==stmt.c.user_id)): # doctest: +NORMALIZE_WHITESPACE
+ ... print u, count
+ SELECT users.id AS users_id, users.name AS users_name,
+ users.fullname AS users_fullname, users.password AS users_password,
+ anon_1.address_count AS anon_1_address_count
+ FROM users LEFT OUTER JOIN (SELECT addresses.user_id AS user_id, count(?) AS address_count
+ FROM addresses GROUP BY addresses.user_id) AS anon_1 ON users.id = anon_1.user_id
+ ORDER BY users.oid
+ ['*']
+ {stop}<User('ed','Ed Jones', 'f8s7ccs')> None
+ <User('wendy','Wendy Williams', 'foobar')> None
+ <User('mary','Mary Contrary', 'xxg527')> None
+ <User('fred','Fred Flinstone', 'blah')> None
+ <User('jack','Jack Bean', 'gjffdd')> 2
- Criterion placed in `filter()` usually correspond to the last `join()` call; if the join was specified with `aliased=True`, class-level criterion against the join's target (or targets) will be appropriately aliased as well.
+### Using EXISTS {@name=exists}
- {python}
- {sql}>>> session.query(User).join('addresses', aliased=True).\
- ... filter(Address.email_address=='jack@google.com').all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id
- WHERE addresses_1.email_address = ? ORDER BY users.oid
- ['jack@google.com']
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+The EXISTS keyword in SQL is a boolean operator which returns True if the given expression contains any rows. It may be used in many scenarios in place of joins, and is also useful for locating rows which do not have a corresponding row in a related table.
-* Filter_by on key=value criterion, combined with a join. Same as `filter()` on column criterion except keyword arguments are used.
+There is an explicit EXISTS construct, which looks like this:
- {python}
- {sql}>>> session.query(User).join('addresses').\
- ... filter_by(email_address='jack@google.com').all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users JOIN addresses ON users.id = addresses.user_id
- WHERE addresses.email_address = ? ORDER BY users.oid
- ['jack@google.com']
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
-
-* Filter on explicit column criterion using `any()` (for collections) or `has()` (for scalar relations). This is a more succinct method than joining, as an `EXISTS` subquery is generated automatically. `any()` means, "find all parent items where any child item of its collection meets this criterion":
+ {python}
+ >>> from sqlalchemy.sql import exists
+ >>> stmt = exists().where(Address.user_id==User.id)
+ {sql}>>> for name, in session.query(User.name).filter(stmt): # doctest: +NORMALIZE_WHITESPACE
+ ... print name
+ SELECT users.name AS users_name
+ FROM users
+ WHERE EXISTS (SELECT *
+ FROM addresses
+ WHERE addresses.user_id = users.id)
+ []
+ {stop}jack
- {python}
- {sql}>>> session.query(User).\
- ... filter(User.addresses.any(Address.email_address=='jack@google.com')).all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE EXISTS (SELECT 1
- FROM addresses
- WHERE users.id = addresses.user_id AND addresses.email_address = ?) ORDER BY users.oid
- ['jack@google.com']
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+The `Query` features several operators which make use of EXISTS automatically. Above, the statement can be expressed along the `User.addresses` relation using `any()`:
+
+ {python}
+ {sql}>>> for name, in session.query(User.name).filter(User.addresses.any()): # doctest: +NORMALIZE_WHITESPACE
+ ... print name
+ SELECT users.name AS users_name
+ FROM users
+ WHERE EXISTS (SELECT 1
+ FROM addresses
+ WHERE users.id = addresses.user_id)
+ []
+ {stop}jack
- `has()` means, "find all parent items where the child item meets this criterion":
+`any()` takes criterion as well, to limit the rows matched:
- {python}
- {sql}>>> session.query(Address).\
- ... filter(Address.user.has(User.name=='jack')).all()
- SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM addresses
- WHERE EXISTS (SELECT 1
- FROM users
- WHERE users.id = addresses.user_id AND users.name = ?) ORDER BY addresses.oid
- ['jack']
- {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
+ {python}
+ {sql}>>> for name, in session.query(User.name).filter(User.addresses.any(Address.email_address.like('%google%'))): # doctest: +NORMALIZE_WHITESPACE
+ ... print name
+ SELECT users.name AS users_name
+ FROM users
+ WHERE EXISTS (SELECT 1
+ FROM addresses
+ WHERE users.id = addresses.user_id AND addresses.email_address LIKE ?)
+ ['%google%']
+ {stop}jack
- Both `has()` and `any()` also accept keyword arguments which are interpreted against the child classes' attributes:
+`has()` is the equivalent of `any()` for many-to-one relations (note the `~` operator used here as well, which means "NOT"):
- {python}
- {sql}>>> session.query(User).\
- ... filter(User.addresses.any(email_address='jack@google.com')).all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE EXISTS (SELECT 1
- FROM addresses
- WHERE users.id = addresses.user_id AND addresses.email_address = ?) ORDER BY users.oid
- ['jack@google.com']
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+ {python}
+ {sql}>>> session.query(Address).filter(~Address.user.has(User.name=='jack')).all() # doctest: +NORMALIZE_WHITESPACE
+ SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address,
+ addresses.user_id AS addresses_user_id
+ FROM addresses
+ WHERE NOT (EXISTS (SELECT 1
+ FROM users
+ WHERE users.id = addresses.user_id AND users.name = ?)) ORDER BY addresses.oid
+ ['jack']
+ {stop}[]
-* Filter_by on instance identity criterion. When comparing to a related instance, `filter_by()` will in most cases not need to reference the child table, since a child instance already contains enough information with which to generate criterion against the parent table. `filter_by()` uses an equality comparison for all relationship types. For many-to-one and one-to-one, this represents all objects which reference the given child object:
-
- {python}
- # locate a user
- {sql}>>> user = session.query(User).filter(User.name=='jack').one() #doctest: +NORMALIZE_WHITESPACE
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.name = ? ORDER BY users.oid
- LIMIT 2 OFFSET 0
- ['jack']
- {stop}
-
- # use the user in a filter_by() expression
- {sql}>>> session.query(Address).filter_by(user=user).all()
- SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM addresses
- WHERE ? = addresses.user_id ORDER BY addresses.oid
- [5]
- {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
+### Common Relation Operators {@name=relationop}
- For one-to-many and many-to-many, it represents all objects which contain the given child object in the related collection:
+Here's a rundown of the operators which build on relations:
+ * equals (used for many-to-one)
+
{python}
- # locate an address
- {sql}>>> address = session.query(Address).\
- ... filter(Address.email_address=='jack@google.com').one() #doctest: +NORMALIZE_WHITESPACE
- SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM addresses
- WHERE addresses.email_address = ? ORDER BY addresses.oid
- LIMIT 2 OFFSET 0
- {stop}['jack@google.com']
-
- # use the address in a filter_by expression
- {sql}>>> session.query(User).filter_by(addresses=address).all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.id = ? ORDER BY users.oid
- [5]
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
-
-* Select instances with a particular parent. This is the "reverse" operation of filtering by instance identity criterion; the criterion is against a relation pointing *to* the desired class, instead of one pointing *from* it. This will utilize the same "optimized" query criterion, usually not requiring any joins:
+ query.filter(Address.user == someuser)
+
+ * not equals (used for many-to-one)
+
{python}
- {sql}>>> session.query(Address).with_parent(user, property='addresses').all()
- SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM addresses
- WHERE ? = addresses.user_id ORDER BY addresses.oid
- [5]
- {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
-
-* Filter on a many-to-one/one-to-one instance identity criterion. The class-level `==` operator will act the same as `filter_by()` for a scalar relation:
+ query.filter(Address.user != someuser)
+
+ * IS NULL (used for many-to-one)
+
{python}
- {sql}>>> session.query(Address).filter(Address.user==user).all()
- SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM addresses
- WHERE ? = addresses.user_id ORDER BY addresses.oid
- [5]
- {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
-
- whereas the `!=` operator will generate a negated EXISTS clause:
-
+ query.filter(Address.user == None)
+
+ * contains (used for one-to-many and many-to-many collections)
+
{python}
- {sql}>>> session.query(Address).filter(Address.user!=user).all()
- SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM addresses
- WHERE NOT (EXISTS (SELECT 1
- FROM users
- WHERE users.id = addresses.user_id AND users.id = ?)) ORDER BY addresses.oid
- [5]
- {stop}[]
-
- a comparison to `None` also generates an IS NULL clause for a many-to-one relation:
-
+ query.filter(User.addresses.contains(someaddress))
+
+ * any (used for one-to-many and many-to-many collections)
+
{python}
- {sql}>>> session.query(Address).filter(Address.user==None).all()
- SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM addresses
- WHERE addresses.user_id IS NULL ORDER BY addresses.oid
- []
- {stop}[]
-
-* Filter on a one-to-many instance identity criterion. The `contains()` operator returns all parent objects which contain the given object as one of its collection members:
-
+ query.filter(User.addresses.any(Address.email_address == 'bar'))
+
+ # also takes keyword arguments:
+ query.filter(User.addresses.any(email_address='bar'))
+
+ * has (used for many-to-one)
+
{python}
- {sql}>>> session.query(User).filter(User.addresses.contains(address)).all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE users.id = ? ORDER BY users.oid
- [5]
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
-
-* Filter on a multiple one-to-many instance identity criterion. The `==` operator can be used with a collection-based attribute against a list of items, which will generate multiple `EXISTS` clauses:
-
+ query.filter(Address.user.has(name='ed'))
+
+ * with_parent (used for any relation)
+
{python}
- {sql}>>> addresses = session.query(Address).filter(Address.user==user).all()
- SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id
- FROM addresses
- WHERE ? = addresses.user_id ORDER BY addresses.oid
- [5]
- {stop}
-
- {sql}>>> session.query(User).filter(User.addresses == addresses).all()
- SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
- FROM users
- WHERE (EXISTS (SELECT 1
- FROM addresses
- WHERE users.id = addresses.user_id AND addresses.id = ?)) AND (EXISTS (SELECT 1
- FROM addresses
- WHERE users.id = addresses.user_id AND addresses.id = ?)) ORDER BY users.oid
- [1, 2]
- {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+ session.query(Address).with_parent(someuser, 'addresses')
## Deleting
[None, 2]
DELETE FROM users WHERE users.id = ?
[5]
- SELECT count(users.id) AS count_1
+ SELECT count(1) AS count_1
FROM users
WHERE users.name = ?
['jack']
{sql}>>> session.query(Address).filter(
... Address.email_address.in_(['jack@google.com', 'j25@yahoo.com'])
... ).count() # doctest: +NORMALIZE_WHITESPACE
- SELECT count(addresses.id) AS count_1
+ SELECT count(1) AS count_1
FROM addresses
WHERE addresses.email_address IN (?, ?)
['jack@google.com', 'j25@yahoo.com']
{stop}2
-Uh oh, they're still there ! Analyzing the flush SQL, we can see that the `user_id` column of each address was set to NULL, but the rows weren't deleted. SQLAlchemy doesn't assume that deletes cascade, you have to tell it so.
+Uh oh, they're still there! Analyzing the flush SQL, we can see that the `user_id` column of each address was set to NULL, but the rows weren't deleted. SQLAlchemy doesn't assume that deletes cascade; you have to tell it to do so.
-So let's rollback our work, and start fresh with new mappers that express the relationship the way we want:
+### Configuring delete/delete-orphan Cascade {@name=cascade}
+
+We will configure **cascade** options on the `User.addresses` relation to change the behavior. While SQLAlchemy allows you to add new attributes and relations to mappings at any point in time, in this case the existing relation needs to be removed, so we need to tear down the mappings completely and start again. This is not a typical operation and is here just for illustrative purposes.
+
+Removing all ORM state is as follows:
{python}
- {sql}>>> session.rollback() # roll back the transaction
- ROLLBACK
-
- >>> session.clear() # clear the session
+ >>> session.close() # roll back and close the transaction
+ >>> from sqlalchemy.orm import clear_mappers
>>> clear_mappers() # clear mappers
-We need to tell the `addresses` relation on `User` that we'd like session.delete() operations to cascade down to the child `Address` objects. Further, we also want `Address` objects which get detached from their parent `User`, whether or not the parent is deleted, to be deleted. For these behaviors we use two **cascade options** `delete` and `delete-orphan`, using the string-based `cascade` option to the `relation()` function:
+Below, we use `mapper()` to reconfigure an ORM mapping for `User` and `Address`, on our existing but currently un-mapped classes. The `User.addresses` relation now has `delete, delete-orphan` cascade on it, which indicates that DELETE operations will cascade to attached `Address` objects as well as `Address` objects which are removed from their parent:
{python}
>>> mapper(User, users_table, properties={ # doctest: +ELLIPSIS
... 'addresses':relation(Address, backref='user', cascade="all, delete, delete-orphan")
... })
- <sqlalchemy.orm.mapper.Mapper object at 0x...>
+ <Mapper at 0x...; User>
+ >>> addresses_table = Address.__table__
>>> mapper(Address, addresses_table) # doctest: +ELLIPSIS
- <sqlalchemy.orm.mapper.Mapper object at 0x...>
+ <Mapper at 0x...; Address>
-Now when we load Jack, removing an address from his `addresses` collection will result in that `Address` being deleted:
+Now when we load Jack (below using `get()`, which loads by primary key), removing an address from his `addresses` collection will result in that `Address` being deleted:
{python}
# load Jack by primary key
- {sql}>>> jack = session.query(User).get(jack.id) #doctest: +NORMALIZE_WHITESPACE
+ {sql}>>> jack = session.query(User).get(5) #doctest: +NORMALIZE_WHITESPACE
BEGIN
SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password
FROM users
WHERE ? = addresses.user_id ORDER BY addresses.oid
[5]
{stop}
-
+
# only one address remains
{sql}>>> session.query(Address).filter(
... Address.email_address.in_(['jack@google.com', 'j25@yahoo.com'])
... ).count() # doctest: +NORMALIZE_WHITESPACE
DELETE FROM addresses WHERE addresses.id = ?
[2]
- SELECT count(addresses.id) AS count_1
+ SELECT count(1) AS count_1
FROM addresses
WHERE addresses.email_address IN (?, ?)
['jack@google.com', 'j25@yahoo.com']
{python}
>>> session.delete(jack)
- {sql}>>> session.commit()
+ {sql}>>> session.query(User).filter_by(name='jack').count() # doctest: +NORMALIZE_WHITESPACE
DELETE FROM addresses WHERE addresses.id = ?
[1]
DELETE FROM users WHERE users.id = ?
[5]
- COMMIT
- {stop}
-
- {sql}>>> session.query(User).filter_by(name='jack').count() # doctest: +NORMALIZE_WHITESPACE
- BEGIN
- SELECT count(users.id) AS count_1
+ SELECT count(1) AS count_1
FROM users
WHERE users.name = ?
['jack']
{sql}>>> session.query(Address).filter(
... Address.email_address.in_(['jack@google.com', 'j25@yahoo.com'])
... ).count() # doctest: +NORMALIZE_WHITESPACE
- SELECT count(addresses.id) AS count_1
+ SELECT count(1) AS count_1
FROM addresses
WHERE addresses.email_address IN (?, ?)
['jack@google.com', 'j25@yahoo.com']
We're moving into the bonus round here, but let's show off a many-to-many relationship. We'll sneak in some other features too, just to take a tour. We'll make our application a blog application, where users can write `BlogPost`s, which have `Keywords` associated with them.
-First some new tables:
+The declarative setup is as follows:
{python}
>>> from sqlalchemy import Text
- >>> post_table = Table('posts', metadata,
- ... Column('id', Integer, primary_key=True),
- ... Column('user_id', Integer, ForeignKey('users.id')),
- ... Column('headline', String(255), nullable=False),
- ... Column('body', Text)
- ... )
-
+
+ >>> # association table
>>> post_keywords = Table('post_keywords', metadata,
- ... Column('post_id', Integer, ForeignKey('posts.id')),
- ... Column('keyword_id', Integer, ForeignKey('keywords.id')))
-
- >>> keywords_table = Table('keywords', metadata,
- ... Column('id', Integer, primary_key=True),
- ... Column('keyword', String(50), nullable=False, unique=True))
+ ... Column('post_id', Integer, ForeignKey('posts.id')),
+ ... Column('keyword_id', Integer, ForeignKey('keywords.id'))
+ ... )
+
+ >>> class BlogPost(Base):
+ ... __tablename__ = 'posts'
+ ...
+ ... id = Column(Integer, primary_key=True)
+ ... user_id = Column(Integer, ForeignKey('users.id'))
+ ... headline = Column(String(255), nullable=False)
+ ... body = Column(Text)
+ ...
+ ... # many to many BlogPost<->Keyword
+ ... keywords = relation('Keyword', secondary=post_keywords, backref='posts')
+ ...
+ ... def __init__(self, headline, body, author):
+ ... self.author = author
+ ... self.headline = headline
+ ... self.body = body
+ ...
+ ... def __repr__(self):
+ ... return "BlogPost(%r, %r, %r)" % (self.headline, self.body, self.author)
+
+ >>> class Keyword(Base):
+ ... __tablename__ = 'keywords'
+ ...
+ ... id = Column(Integer, primary_key=True)
+ ... keyword = Column(String(50), nullable=False, unique=True)
+ ...
+ ... def __init__(self, keyword):
+ ... self.keyword = keyword
+
+Above, the many-to-many relation is `BlogPost.keywords`. The defining feature of a many-to-many relation is the `secondary` keyword argument, which references a `Table` object representing the association table. This table only contains columns which reference the two sides of the relation; if it has *any* other columns, such as its own primary key or foreign keys to other tables, SQLAlchemy requires a different usage pattern called the "association object", described at [advdatamapping_relation_patterns_association](rel:advdatamapping_relation_patterns_association).
+
+The many-to-many relation is also bidirectional using the `backref` keyword. This is the one case where usage of `backref` is generally required, since if a separate `posts` relation were added to the `Keyword` entity instead, both relations would independently add and remove rows from the `post_keywords` table and produce conflicts.
+
+We would also like our `BlogPost` class to have an `author` field. We will add this as another bidirectional relationship. One issue we'll have is that a single user might have lots of blog posts; when we access `User.posts`, we'd like to be able to filter results further so as not to load the entire collection. For this we use a setting accepted by `relation()` called `lazy='dynamic'`, which configures an alternate **loader strategy** on the attribute. To use it on the "reverse" side of a `relation()`, we use the `backref()` function:
+
+ {python}
+ >>> from sqlalchemy.orm import backref
+ >>> # "dynamic" loading relation to User
+ >>> BlogPost.author = relation(User, backref=backref('posts', lazy='dynamic'))
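+
+With this loader strategy in place, `User.posts` behaves like a `Query` rather than a fully loaded list, so additional criteria may be applied before any rows are fetched; a brief sketch (assuming a `User` instance named `wendy`):
+
+ {python}
+ # filter the posts collection without loading all of it
+ wendy.posts.filter(BlogPost.headline.like('%first%')).all()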
+
+Create the new tables:
+ {python}
{sql}>>> metadata.create_all(engine) # doctest: +NORMALIZE_WHITESPACE
PRAGMA table_info("users")
{}
{}
COMMIT
-Then some classes:
-
- {python}
- >>> class BlogPost(object):
- ... def __init__(self, headline, body, author):
- ... self.author = author
- ... self.headline = headline
- ... self.body = body
- ... def __repr__(self):
- ... return "BlogPost(%r, %r, %r)" % (self.headline, self.body, self.author)
-
- >>> class Keyword(object):
- ... def __init__(self, keyword):
- ... self.keyword = keyword
-
-And the mappers. `BlogPost` will reference `User` via its `author` attribute:
-
- {python}
- >>> from sqlalchemy.orm import backref
-
- >>> mapper(Keyword, keywords_table) # doctest: +ELLIPSIS
- <sqlalchemy.orm.mapper.Mapper object at 0x...>
-
- >>> mapper(BlogPost, post_table, properties={ # doctest: +ELLIPSIS
- ... 'author':relation(User, backref=backref('posts', lazy='dynamic')),
- ... 'keywords':relation(Keyword, secondary=post_keywords)
- ... })
- <sqlalchemy.orm.mapper.Mapper object at 0x...>
-
-There's three new things in the above mapper:
-
- * the `User` relation has a backref, like we've used before, except this time it references a function called `backref()`. This function is used when yo'd like to specify keyword options for the backwards relationship.
- * the keyword option we specified to `backref()` is `lazy="dynamic"`. This sets a default **loader strategy** on the attribute, in this case a special strategy that allows partial loading of results.
- * The `keywords` relation uses a keyword argument `secondary` to indicate the **association table** for the many to many relationship from `BlogPost` to `Keyword`.
-
Usage is not too different from what we've been doing. Let's give Wendy some blog posts:
{python}
['wendy']
>>> post = BlogPost("Wendy's Blog Post", "This is a test", wendy)
- >>> session.save(post)
+ >>> session.add(post)
We're storing keywords uniquely in the database, but we know that we don't have any yet, so we can just create them:
>>> post.keywords.append(Keyword('wendy'))
>>> post.keywords.append(Keyword('firstpost'))
-We can now look up all blog posts with the keyword 'firstpost'. We'll use a special collection operator `any` to locate "blog posts where any of its keywords has the keyword string 'firstpost'":
+We can now look up all blog posts with the keyword 'firstpost'. We'll use the `any` operator to locate "blog posts where any of its keywords has the keyword string 'firstpost'":
{python}
{sql}>>> session.query(BlogPost).filter(BlogPost.keywords.any(keyword='firstpost')).all()
If we want to look up just Wendy's posts, we can narrow the query down to her as the author:
{python}
- {sql}>>> session.query(BlogPost).with_parent(wendy).\
+ {sql}>>> session.query(BlogPost).filter(BlogPost.author==wendy).\
... filter(BlogPost.keywords.any(keyword='firstpost')).all()
SELECT posts.id AS posts_id, posts.user_id AS posts_user_id, posts.headline AS posts_headline, posts.body AS posts_body
FROM posts
- Using the Session {@name=unitofwork}
+Using the Session {@name=unitofwork}
============
The [Mapper](rel:advdatamapping) is the entrypoint to the configurational API of the SQLAlchemy object relational mapper. But the primary object one works with when using the ORM is the [Session](rel:docstrings_sqlalchemy.orm.session_Session).
In the most general sense, the `Session` establishes all conversations with the database and represents a "holding zone" for all the mapped instances which you've loaded or created during its lifespan. It implements the [Unit of Work](http://martinfowler.com/eaaCatalog/unitOfWork.html) pattern, which means it keeps track of all changes which occur, and is capable of **flushing** those changes to the database as appropriate. Another important facet of the `Session` is that it's also maintaining **unique** copies of each instance, where "unique" means "only one object with a particular primary key" - this pattern is called the [Identity Map](http://martinfowler.com/eaaCatalog/identityMap.html).
-Beyond that, the `Session` implements an interface which let's you move objects in or out of the session in a variety of ways, it provides the entryway to a `Query` object which is used to query the database for data, it is commonly used to provide transactional boundaries (though this is optional), and it also can serve as a configurational "home base" for one or more `Engine` objects, which allows various vertical and horizontal partitioning strategies to be achieved.
+Beyond that, the `Session` implements an interface which lets you move objects in or out of the session in a variety of ways. It provides the entryway to a `Query` object, which is used to query the database for data, and it also provides a transactional context for SQL operations which rides on top of the transactional capabilities of `Engine` and `Connection` objects.
## Getting a Session
-The `Session` object exists just as a regular Python object, which can be directly instantiated. However, it takes a fair amount of keyword options, several of which you probably want to set explicitly. It's fairly inconvenient to deal with the "configuration" of a session every time you want to create one. Therefore, SQLAlchemy recommends the usage of a helper function called `sessionmaker()`, which typically you call only once for the lifespan of an application. This function creates a customized `Session` subclass for you, with your desired configurational arguments pre-loaded. Then, whenever you need a new `Session`, you use your custom `Session` class with no arguments to create the session.
+`Session` is a regular Python class which can be directly instantiated. However, to standardize how sessions are configured and acquired, the `sessionmaker()` function is normally used to create a top level `Session` configuration which can then be used throughout an application without the need to repeat the configurational arguments.
### Using a sessionmaker() Configuration {@name=sessionmaker}
from sqlalchemy.orm import sessionmaker
# create a configured "Session" class
- Session = sessionmaker(autoflush=True, transactional=True)
+ Session = sessionmaker(bind=some_engine)
# create a Session
sess = Session()
# work with sess
- sess.save(x)
+ myobject = MyObject('foo', 'bar')
+ sess.add(myobject)
sess.commit()
# close when finished
sess.close()
-Above, the `sessionmaker` call creates a class for us, which we assign to the name `Session`. This class is a subclass of the actual `sqlalchemy.orm.session.Session` class, which will instantiate with the arguments of `autoflush=True` and `transactional=True`.
+Above, the `sessionmaker` call creates a class for us, which we assign to the name `Session`. This class is a subclass of the actual `sqlalchemy.orm.session.Session` class, which will instantiate with a particular bound engine.
When you write your application, place the call to `sessionmaker()` somewhere global, and then make your new `Session` class available to the rest of your application.
-### Binding Session to an Engine or Connection {@name=binding}
+### Binding Session to an Engine {@name=binding}
-In our previous example regarding `sessionmaker()`, nowhere did we specify how our session would connect to our database. When the session is configured in this manner, it will look for a database engine to connect with via the `Table` objects that it works with - the chapter called [metadata_tables_binding](rel:metadata_tables_binding) describes how to associate `Table` objects directly with a source of database connections.
-
-However, it is often more straightforward to explicitly tell the session what database engine (or engines) you'd like it to communicate with. This is particularly handy with multiple-database scenarios where the session can be used as the central point of configuration. To achieve this, the constructor keyword `bind` is used for a basic single-database configuration:
-
- {python}
- # create engine
- engine = create_engine('postgres://...')
-
- # bind custom Session class to the engine
- Session = sessionmaker(bind=engine, autoflush=True, transactional=True)
-
- # work with the session
- sess = Session()
-
-One common issue with the above scenario is that an application will often organize its global imports before it ever connects to a database. Since the `Session` class created by `sessionmaker()` is meant to be a global application object (note we are saying the session *class*, not a session *instance*), we may not have a `bind` argument available. For this, the `Session` class returned by `sessionmaker()` supports post-configuration of all options, through its method `configure()`:
+In our previous example regarding `sessionmaker()`, we specified a `bind` referring to a particular `Engine`. If we'd like to construct a `sessionmaker()` without an engine available and bind it later on, or to specify other options to an existing `sessionmaker()`, we may use the `configure()` method:
{python}
# configure Session class with desired options
- Session = sessionmaker(autoflush=True, transactional=True)
+ Session = sessionmaker()
# later, we create the engine
engine = create_engine('postgres://...')
+
+    # associate the engine with our custom Session class
+    Session.configure(bind=engine)

# work with the session
sess = Session()
-The `Session` also has the ability to be bound to multiple engines. Descriptions of these scenarios are described in [unitofwork_partitioning](rel:unitofwork_partitioning).
+It's actually entirely optional to bind a Session to an engine. If the underlying mapped `Table` objects use "bound" metadata, the `Session` will make use of the bound engine instead (or will even use multiple engines if multiple binds are present within the mapped tables). "Bound" metadata is described at [metadata_tables_binding](rel:metadata_tables_binding).
+The `Session` also has the ability to be bound to multiple engines explicitly. These scenarios are described in [unitofwork_partitioning](rel:unitofwork_partitioning).
-#### Binding Session to a Connection {@name=connection}
+### Binding Session to a Connection {@name=connection}
-The examples involving `bind` so far are dealing with the `Engine` object, which is, like the `Session` class itself, a global configurational object. The `Session` can also be bound to an individual database `Connection`. The reason you might want to do this is if your application controls the boundaries of transactions using distinct `Transaction` objects (these objects are described in [dbengine_transactions](rel:dbengine_transactions)). You'd have a transactional `Connection`, and then you'd want to work with an ORM-level `Session` which participates in that transaction. Since `Connection` is definitely not a globally-scoped object in all but the most rudimental commandline applications, you can bind an individual `Session()` instance to a particular `Connection` not at class configuration time, but at session instance construction time:
+The `Session` can also be explicitly bound to an individual database `Connection`. Reasons for doing this include joining a `Session` with an ongoing transaction local to a specific `Connection` object, or bypassing connection pooling by keeping connections persistently checked out and associated with distinct, long-running sessions:
{python}
# global application scope. create Session class, engine
- Session = sessionmaker(autoflush=True, transactional=True)
+ Session = sessionmaker()
engine = create_engine('postgres://...')
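+
+    # local scope, such as within a controller function
+    # (a sketch of the pattern described in the text above)
+    # connect to the database
+    connection = engine.connect()
+
+    # bind an individual Session to the connection
+    sess = Session(bind=connection)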
### Using create_session() {@name=createsession}
-As an alternative to `sessionmaker()`, `create_session()` exists literally as a function which calls the normal `Session` constructor directly. All arguments are passed through and the new `Session` object is returned:
+As an alternative to `sessionmaker()`, `create_session()` is a function which calls the normal `Session` constructor directly. All arguments are passed through and the new `Session` object is returned:
{python}
- session = create_session(bind=myengine)
-
-The `create_session()` function doesn't add any functionality to the regular `Session`, it just sets up a default argument set of `autoflush=False, transactional=False`. But also, by calling `create_session()` instead of instantiating `Session` directly, you leave room in your application to change the type of session which the function creates. For example, an application which is calling `create_session()` in many places, which is typical for a pre-0.4 application, can be changed to use a `sessionmaker()` by just assigning the return of `sessionmaker()` to the `create_session` name:
+ session = create_session(bind=myengine, autocommit=True, autoflush=False)
- {python}
- # change from:
- from sqlalchemy.orm import create_session
+### Configurational Arguments {@name=configuration}
- # to:
- create_session = sessionmaker()
+Configurational arguments accepted by `sessionmaker()` and `create_session()` are the same as those of the `Session` class itself, and are described at [docstrings_sqlalchemy.orm_modfunc_sessionmaker](rel:docstrings_sqlalchemy.orm_modfunc_sessionmaker).
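+
+For example, the common arguments may be passed to either callable in the same way; below is a brief sketch, assuming an `Engine` named `engine` created elsewhere:
+
+    {python}
+    # equivalent configurations
+    Session = sessionmaker(bind=engine, autoflush=True, autocommit=False)
+    session = create_session(bind=engine, autoflush=True, autocommit=False)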
## Using the Session
-A typical session conversation starts with creating a new session, or acquiring one from an ongoing context. You save new objects and load existing ones, make changes, mark some as deleted, and then persist your changes to the database. If your session is transactional, you use `commit()` to persist any remaining changes and to commit the transaction. If not, you call `flush()` which will flush any remaining data to the database.
-
-Below, we open a new `Session` using a configured `sessionmaker()`, make some changes, and commit:
-
- {python}
- # configured Session class
- Session = sessionmaker(autoflush=True, transactional=True)
-
- sess = Session()
- d = Data(value=10)
- sess.save(d)
- d2 = sess.query(Data).filter(Data.value==15).one()
- d2.value = 19
- sess.commit()
-
### Quickie Intro to Object States {@name=states}
It's helpful to know the states which an instance can have within a session:
* *Transient* - an instance that's not in a session, and is not saved to the database; i.e. it has no database identity. The only relationship such an object has to the ORM is that its class has a `mapper()` associated with it.
-* *Pending* - when you `save()` a transient instance, it becomes pending. It still wasn't actually flushed to the database yet, but it will be when the next flush occurs.
+* *Pending* - when you `add()` a transient instance, it becomes pending. It hasn't actually been flushed to the database yet, but it will be when the next flush occurs.
* *Persistent* - An instance which is present in the session and has a record in the database. You get persistent instances by either flushing so that the pending instances become persistent, or by querying the database for existing instances (or moving persistent instances from other sessions into your local session).
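+
+For example, assuming a hypothetical mapped `User` class, an instance moves through these states as follows:
+
+    {python}
+    user = User(name='ed')    # transient; not in a session, no database identity
+    sess.add(user)            # pending; will be INSERTed on the next flush
+    sess.flush()              # persistent; the instance now has a database identity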
You typically invoke `Session()` when you first need to talk to your database, and want to save some objects or load some existing ones. Then, you work with it, save your changes, and then dispose of it... or at the very least `close()` it. It's not a "global" kind of object, and should be handled more like a "local variable", as it's generally **not** safe to use with concurrent threads. Sessions are very inexpensive to make, and don't use any resources whatsoever until they are first used... so create some!
- There is also a pattern whereby you're using a **contextual session**, this is described later in [unitofwork_contextual](rel:unitofwork_contextual). In this pattern, a helper object is maintaining a `Session` for you, most commonly one that is local to the current thread (and sometimes also local to an application instance). SQLAlchemy 0.4 has worked this pattern out such that it still *looks* like you're creating a new session as you need one...so in that case, it's still a guaranteed win to just say `Session()` whenever you want a session.
+ There is also a pattern whereby you're using a **contextual session**; this is described later in [unitofwork_contextual](rel:unitofwork_contextual). In this pattern, a helper object maintains a `Session` for you, most commonly one that is local to the current thread (and sometimes also local to an application instance). SQLAlchemy has worked this pattern out such that it still *looks* like you're creating a new session as you need one... so in that case, it's still a guaranteed win to just say `Session()` whenever you want a session.
* Is the Session a cache ?
But the bigger point here is, you should not *want* to use the session with multiple concurrent threads. That would be like having everyone at a restaurant all eat from the same plate. The session is a local "workspace" that you use for a specific set of tasks; you don't want to, or need to, share that session with other threads who are doing some other task. If, on the other hand, there are other threads participating in the same task you are, such as in a desktop graphical application, then you would be sharing the session with those threads, but you also will have implemented a proper locking scheme (or your graphical framework does) so that those threads do not collide.
-### Session Attributes {@name=attributes}
-
-The session provides a set of attributes and collection-oriented methods which allow you to view the current state of the session.
-
-The **identity map** is accessed by the `identity_map` attribute, which provides a dictionary interface. The keys are "identity keys", which are attached to all persistent objects by the attribute `_instance_key`:
-
- {python}
- >>> myobject._instance_key
- (<class 'test.tables.User'>, (7,))
-
- >>> myobject._instance_key in session.identity_map
- True
-
- >>> session.identity_map.values()
- [<__main__.User object at 0x712630>, <__main__.Address object at 0x712a70>]
-
-The identity map is a weak-referencing dictionary by default. This means that objects which are dereferenced on the outside will be removed from the session automatically. Note that objects which are marked as "dirty" will not fall out of scope until after changes on them have been flushed; special logic kicks in at the point of auto-removal which ensures that no pending changes remain on the object, else a temporary strong reference is created to the object.
-
-Some people prefer objects to stay in the session until explicitly removed in all cases; for this, you can specify the flag `weak_identity_map=False` to the `create_session` or `sessionmaker` functions so that the `Session` will use a regular dictionary.
-
-While the `identity_map` accessor is currently the actual dictionary used by the `Session` to store instances, you should not add or remove items from this dictionary. Use the session methods `save_or_update()` and `expunge()` to add or remove items.
-
-The Session also supports an iterator interface in order to see all objects in the identity map:
-
- {python}
- for obj in session:
- print obj
-
-As well as `__contains__()`:
-
- {python}
- if obj in session:
- print "Object is present"
-
-The session is also keeping track of all newly created (i.e. pending) objects, all objects which have had changes since they were last loaded or saved (i.e. "dirty"), and everything that's been marked as deleted.
-
- {python}
- # pending objects recently added to the Session
- session.new
-
- # persistent objects which currently have changes detected
- # (this collection is now created on the fly each time the property is called)
- session.dirty
-
- # persistent objects that have been marked as deleted via session.delete(obj)
- session.deleted
-
### Querying
-The `query()` function takes one or more classes and/or mappers, along with an optional `entity_name` parameter, and returns a new `Query` object which will issue mapper queries within the context of this Session. For each mapper is passed, the Query uses that mapper. For each class, the Query will locate the primary mapper for the class using `class_mapper()`.
+The `query()` function takes one or more *entities* and returns a new `Query` object which will issue mapper queries within the context of this Session. An entity is defined as a mapped class, a `Mapper` object, an orm-enabled *descriptor*, or an `AliasedClass` object (a future release will also include an `Entity` object for use with entity_name mappers).
{python}
# query from a class
session.query(User).filter_by(name='ed').all()
# query with multiple classes, returns tuples
- session.query(User).add_entity(Address).join('addresses').filter_by(name='ed').all()
+ session.query(User, Address).join('addresses').filter_by(name='ed').all()
+
+ # query using orm-enabled descriptors
+ session.query(User.name, User.fullname).all()
# query from a mapper
- query = session.query(usermapper)
- x = query.get(1)
-
- # query from a class mapped with entity name 'alt_users'
- q = session.query(User, entity_name='alt_users')
- y = q.options(eagerload('orders')).all()
-
-`entity_name` is an optional keyword argument sent with a class object, in order to further qualify which primary mapper to be used; this only applies if there was a `Mapper` created with that particular class/entity name combination, else an exception is raised. All of the methods on Session which take a class or mapper argument also take the `entity_name` argument, so that a given class can be properly matched to the desired primary mapper.
+ user_mapper = class_mapper(User)
+ session.query(user_mapper)
-All instances retrieved by the returned `Query` object will be stored as persistent instances within the originating `Session`.
+When `Query` returns results, each object instantiated is stored within the identity map. When an incoming row matches an object which is already present, the same object is returned. Whether or not the row's data is also populated onto that existing object depends upon whether the attributes of the instance have been *expired* or not. As of 0.5, a default-configured `Session` automatically expires all instances along transaction boundaries, so that with a normally isolated transaction, there shouldn't be any issue of instances representing data which is stale with regards to the current transaction.
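+
+A brief sketch of the identity map behavior, assuming a hypothetical mapped `User` class with an integer primary key column `id`:
+
+    {python}
+    u1 = session.query(User).get(5)
+    u2 = session.query(User).filter(User.id == 5).one()
+
+    # both queries return the very same object
+    assert u1 is u2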
-### Saving New Instances
+### Adding New or Existing Items
-`save()` is called with a single transient instance as an argument, which is then added to the Session and becomes pending. When the session is next flushed, the instance will be saved to the database. If the given instance is not transient, meaning it is either attached to an existing Session or it has a database identity, an exception is raised.
+`add()` is used to place instances in the session. For *transient* (i.e. brand new) instances, this will have the effect of an INSERT taking place for those instances upon the next flush. For instances which are *persistent* (i.e. were loaded by this session), they are already present and do not need to be added. Instances which are *detached* (i.e. have been removed from a session) may be re-associated with a session using this method:
{python}
user1 = User(name='user1')
user2 = User(name='user2')
- session.save(user1)
- session.save(user2)
+ session.add(user1)
+ session.add(user2)
session.commit() # write changes to the database
-There's also other ways to have objects saved to the session automatically; one is by using cascade rules, and the other is by using a contextual session. Both of these are described later.
+To add a list of items to the session at once, use `add_all()`:
-### Updating/Merging Existing Instances
+ {python}
+ session.add_all([item1, item2, item3])
-The `update()` method is used when you have a detached instance, and you want to put it back into a `Session`. Recall that "detached" means the object has a database identity.
+The `add()` operation **cascades** along the `save-update` cascade. For more details see the section [unitofwork_cascades](rel:unitofwork_cascades).
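+
+As a brief sketch, assume a hypothetical `User` class with an `addresses` relation() to `Address`, configured with the default "save-update" cascade:
+
+    {python}
+    user = User(name='ed')
+    address = Address(email_address='ed@foo.com')
+    user.addresses.append(address)
+
+    # the Address becomes pending along with the User,
+    # via the save-update cascade
+    session.add(user)
+    assert address in session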
-Since `update()` is a little picky that way, most people use `save_or_update()`, which checks for an `_instance_key` attribute, and based on whether it's there or not, calls either `save()` or `update()`:
+### Merging
- {python}
- # load user1 using session 1
- user1 = sess1.query(User).get(5)
-
- # remove it from session 1
- sess1.expunge(user1)
-
- # move it into session 2
- sess2.save_or_update(user1)
+`merge()` reconciles the current state of an instance and its associated children with existing data in the database, and returns a copy of the instance associated with the session. Usage is as follows:
-`update()` is also an operation that can happen automatically using cascade rules, just like `save()`.
+ {python}
+ merged_object = session.merge(existing_object)
-`merge()` on the other hand is a little like `update()`, except it creates a **copy** of the given instance in the session, and returns to you that instance; the instance you send it never goes into the session. `merge()` is much fancier than `update()`; it will actually look to see if an object with the same primary key is already present in the session, and if not will load it by primary key. Then, it will merge the attributes of the given object into the one which it just located.
+When given an instance, it follows these steps:
-This method is useful for bringing in objects which may have been restored from a serialization, such as those stored in an HTTP session, where the object may be present in the session already:
+ * It examines the primary key of the instance. If it's present, it attempts to load an instance with that primary key (or pulls from the local identity map).
+ * If there's no primary key on the given instance, or the given primary key does not exist in the database, a new instance is created.
+ * The state of the given instance is then copied onto the located/newly created instance.
+ * The operation is cascaded to associated child items along the `merge` cascade. Note that all changes present on the given instance, including changes to collections, are merged.
+ * The new instance is returned.
- {python}
- # deserialize an object
- myobj = pickle.loads(mystring)
+With `merge()`, the given instance is not placed within the session, and can be associated with a different session or detached. `merge()` is very useful for taking the state of any kind of object structure without regard for its origins or current session associations and placing that state within a session. Here are two examples:
- # "merge" it. if the session already had this object in the
- # identity map, then you get back the one from the current session.
- myobj = session.merge(myobj)
+ * An application which reads an object structure from a file and wishes to save it to the database might parse the file, build up the structure, and then use `merge()` to save it to the database, ensuring that the data within the file is used to formulate the primary key of each element of the structure. Later, when the file has changed, the same process can be re-run, producing a slightly different object structure, which can then be merged in again, and the `Session` will automatically update the database to reflect those changes.
+ * A web application stores mapped entities within an HTTP session object. When each request starts up, the serialized data can be merged into the session, so that the original entity may be safely shared among requests and threads.
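+
+The second case might resemble the following sketch, assuming a pickled, mapped `User` instance stored in a hypothetical `http_session` dictionary:
+
+    {python}
+    import pickle
+
+    # deserialize the entity stored by a previous request
+    user = pickle.loads(http_session['user'])
+
+    # merge it; the returned instance is the one held by the Session
+    merged_user = session.merge(user)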
-`merge()` includes an important option called `dont_load`. When this boolean flag is set to `True`, the merge of a detached object will not force a `get()` of that object from the database. Normally, `merge()` issues a `get()` for every existing object so that it can load the most recent state of the object, which is then modified according to the state of the given object. With `dont_load=True`, the `get()` is skipped and `merge()` places an exact copy of the given object in the session. This allows objects which were retrieved from a caching system to be copied back into a session without any SQL overhead being added.
+`merge()` is frequently used by applications which implement their own second level caches. This refers to an application which uses an in-memory dictionary, or a tool like Memcached, to store objects over long-running spans of time. When such an object needs to exist within a `Session`, `merge()` is a good choice since it leaves the original cached object untouched. For this use case, merge provides a keyword option called `dont_load=True`. When this boolean flag is set to `True`, `merge()` will not issue any SQL to reconcile the given object against the current state of the database, thereby reducing query overhead. The limitation is that the given object and all of its children may not contain any pending changes, and it's also of course possible that newer information in the database will not be present on the merged object, since no load is issued.
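+
+A brief sketch of the caching scenario, using a hypothetical `cache` dictionary:
+
+    {python}
+    # the cached object is copied into the session without any SQL
+    # being emitted; the original stays in the cache untouched
+    user = session.merge(cache['user:5'], dont_load=True)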
### Deleting
### Flushing
-This is the main gateway to what the `Session` does best, which is save everything ! It should be clear by now what a flush looks like:
+When the `Session` is used with its default configuration, the flush step is nearly always done transparently. Specifically, the flush occurs before any individual `Query` is issued, as well as within the `commit()` call before the transaction is committed. This behavior can be disabled by constructing `sessionmaker()` with the flag `autoflush=False`.
+
+Regardless of the autoflush setting, a flush can always be forced by issuing `flush()`:
{python}
session.flush()
-It also can be called with a list of objects; in this form, the flush operation will be limited only to the objects specified in the list:
+`flush()` also supports the ability to flush a subset of objects which are present in the session, by passing a list of objects:
{python}
# saves only user1 and address2. all other modified
# objects remain present in the session.
session.flush([user1, address2])
-This second form of flush should be used carefully as it will not necessarily locate other dependent objects within the session, whose database representation may have foreign constraint relationships with the objects being operated upon.
-
-Theres also a way to have `flush()` called automatically before each query; this is called "autoflush" and is described below.
-
-Note that when using a `Session` that has been placed into a transaction, the `commit()` method will also `flush()` the `Session` unconditionally before committing the transaction.
+This second form of flush should be used carefully as it currently does not cascade, meaning that it will not necessarily affect other objects directly associated with the objects given.
-Note that flush **does not change** the state of any collections or entity relationships in memory; for example, if you set a foreign key attribute `b_id` on object `A` with the identifier `B.id`, the change will be flushed to the database, but `A` will not have `B` added to its collection. If you want to manipulate foreign key attributes directly, `refresh()` or `expire()` the objects whose state needs to be refreshed subsequent to flushing.
+The flush process *always* occurs within a transaction, even if the `Session` has been configured with `autocommit=True`, a setting that disables the session's persistent transactional state. If no transaction is present, `flush()` creates its own transaction and commits it. Any failures during flush will always result in a rollback of whatever transaction is present.
-### Autoflush
+### Committing
-A session can be configured to issue `flush()` calls before each query. This allows you to immediately have DB access to whatever has been saved to the session. It's recommended to use autoflush with `transactional=True`, that way an unexpected flush call won't permanently save to the database:
+`commit()` is used to commit the current transaction. It always issues `flush()` beforehand to flush any remaining state to the database; this is independent of the "autoflush" setting. If no transaction is present, it raises an error. Note that the default behavior of the `Session` is that a transaction is always present; this behavior can be disabled by setting `autocommit=True`. In autocommit mode, a transaction can be initiated by calling the `begin()` method.
- {python}
- Session = sessionmaker(autoflush=True, transactional=True)
- sess = Session()
- u1 = User(name='jack')
- sess.save(u1)
-
- # reload user1
- u2 = sess.query(User).filter_by(name='jack').one()
- assert u2 is u1
+Another behavior of `commit()` is that by default it expires the state of all instances present after the commit is complete. This is so that when the instances are next accessed, either through attribute access or by them being present in a `Query` result set, they receive the most recent state. To disable this behavior, configure `sessionmaker()` with `autoexpire=False`.
- # commit session, flushes whatever is remaining
- sess.commit()
+Normally, instances loaded into the `Session` are never changed by subsequent queries; the assumption is that the current transaction is isolated so the state most recently loaded is correct as long as the transaction continues. Setting `autocommit=True` works against this model to some degree since the `Session` behaves in exactly the same way with regard to attribute state, except no transaction is present.
-Autoflush is particularly handy when using "dynamic" mapper relations, so that changes to the underlying collection are immediately available via its query interface.
+### Rolling Back
-### Committing
+`rollback()` rolls back the current transaction. With a default-configured session, the post-rollback state of the session is as follows:
-The `commit()` method on `Session` is used specifically when the `Session` is in a transactional state. The two ways that a session may be placed in a transactional state are to create it using the `transactional=True` option, or to call the `begin()` method.
+ * All connections are rolled back and returned to the connection pool, unless the Session was bound directly to
+ a Connection, in which case the connection is maintained (but still rolled back).
+ * Objects which were initially in the *pending* state when they were added to the `Session` within the lifespan of the transaction are expunged, corresponding to their INSERT statement being rolled back. The state of their attributes remains unchanged.
+ * Objects which were marked as *deleted* within the lifespan of the transaction are promoted back to the *persistent* state, corresponding to their DELETE statement being rolled back. Note that if those objects were first *pending* within the transaction, that operation takes precedence instead.
+ * All objects not expunged are fully expired. This aspect of the behavior may be disabled by configuring `sessionmaker()` with `autoexpire=False`.
-`commit()` serves **two** purposes; it issues a `flush()` unconditionally to persist any remaining pending changes, and it issues a commit to all currently managed database connections. In the typical case this is just a single connection. After the commit, connection resources which were allocated by the `Session` are released. This holds true even for a `Session` which specifies `transactional=True`; when such a session is committed, the next transaction is not "begun" until the next database operation occurs.
+With that state understood, the `Session` may safely continue usage after a rollback occurs (note that this is a new feature as of version 0.5).
-See the section below on "Managing Transactions" for further detail.
+When a `flush()` fails, typically for reasons like primary key, foreign key, or "not nullable" constraint violations, a `rollback()` is issued automatically (it's currently not possible for a flush to continue after a partial failure). However, the flush process always uses its own transactional demarcator called a *subtransaction*, which is described more fully in the docstrings for `Session`. What it means here is that even though the database transaction has been rolled back, the end user must still issue `rollback()` to fully reset the state of the `Session`.
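+
+In practice, the pattern resembles this sketch:
+
+    {python}
+    try:
+        sess.commit()    # flush fails, e.g. on a constraint violation
+    except:
+        sess.rollback()  # fully resets the state of the Session
+        raise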
-### Expunge / Clear
+### Expunging
Expunge removes an object from the Session, sending persistent instances to the detached state, and pending instances to the transient state:
{python}
session.expunge(obj1)
-Use `expunge` when you'd like to remove an object altogether from memory, such as before calling `del` on it, which will prevent any "ghost" operations occurring when the session is flushed.
-
-This `clear()` method is equivalent to `expunge()`-ing everything from the Session:
-
- {python}
- session.clear()
-
-However note that the `clear()` method does not reset any transactional state or connection resources; therefore what you usually want to call instead of `clear()` is `close()`.
+To remove all items, call `session.expunge_all()`.
### Closing
-The `close()` method issues a `clear()`, and releases any transactional/connection resources. When connections are returned to the connection pool, whatever transactional state exists is rolled back.
-
-When `close()` is called, the `Session` is in the same state as when it was first created, and is safe to be used again. `close()` is especially important when using a contextual session, which remains in memory after usage. By issuing `close()`, the session will be clean for the next request that makes use of it.
+The `close()` method issues an `expunge_all()`, and releases any transactional/connection resources. When connections are returned to the connection pool, transactional state is rolled back as well.
### Refreshing / Expiring
session.expire(obj1, ['hello', 'world'])
session.expire(obj2, ['hello', 'world'])
+The full contents of the session may be expired at once using `expire_all()`:
+
+ {python}
+ session.expire_all()
+
+`refresh()` and `expire()` are usually not needed when working with a default-configured `Session`. The usual case for them is when an UPDATE or DELETE has been issued manually within the transaction using `Session.execute()`.
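+
+A sketch of that scenario, assuming a hypothetical `users` table and a mapped instance `myuser`:
+
+    {python}
+    # UPDATE issued outside of the unit of work
+    sess.execute("UPDATE users SET name='ed' WHERE id=:id", {'id': 5})
+
+    # expire the affected instance; its attributes reload on next access
+    sess.expire(myuser)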
+
+### Session Attributes {@name=attributes}
+
+The `Session` itself behaves somewhat like a set. All items present may be accessed using the iterator interface:
+
+ {python}
+ for obj in session:
+ print obj
+
+And presence may be tested for using regular "contains" semantics:
+
+ {python}
+ if obj in session:
+ print "Object is present"
+
+The session is also keeping track of all newly created (i.e. pending) objects, all objects which have had changes since they were last loaded or saved (i.e. "dirty"), and everything that's been marked as deleted.
+
+ {python}
+ # pending objects recently added to the Session
+ session.new
+
+ # persistent objects which currently have changes detected
+ # (this collection is now created on the fly each time the property is called)
+ session.dirty
+
+ # persistent objects that have been marked as deleted via session.delete(obj)
+ session.deleted
+
+Note that objects within the session are by default *weakly referenced*. This means that when they are dereferenced in the outside application, they fall out of scope from within the `Session` as well and are subject to garbage collection by the Python interpreter. The exceptions to this include objects which are pending, objects which are marked as deleted, or persistent objects which have pending changes on them. After a full flush, these collections are all empty, and all objects are again weakly referenced. To disable the weak referencing behavior and force all objects within the session to remain until explicitly expunged, configure `sessionmaker()` with the `weak_identity_map=False` setting.
+
## Cascades
Mappers support the concept of configurable *cascade* behavior on `relation()`s. This behavior controls how the Session should treat the instances that have a parent-child relationship with another instance that is operated upon by the Session. Cascade is indicated as a comma-separated list of string keywords, with the possible values `all`, `delete`, `save-update`, `refresh-expire`, `merge`, `expunge`, and `delete-orphan`.
'customer' : relation(User, users_table, user_orders_table, cascade="save-update"),
})
-The above mapper specifies two relations, `items` and `customer`. The `items` relationship specifies "all, delete-orphan" as its `cascade` value, indicating that all `save`, `update`, `merge`, `expunge`, `refresh` `delete` and `expire` operations performed on a parent `Order` instance should also be performed on the child `Item` instances attached to it (`save` and `update` are cascaded using the `save_or_update()` method, so that the database identity of the instance doesn't matter). The `delete-orphan` cascade value additionally indicates that if an `Item` instance is no longer associated with an `Order`, it should also be deleted. The "all, delete-orphan" cascade argument allows a so-called *lifecycle* relationship between an `Order` and an `Item` object.
-
-The `customer` relationship specifies only the "save-update" cascade value, indicating most operations will not be cascaded from a parent `Order` instance to a child `User` instance, except for if the `Order` is attached with a particular session, either via the `save()`, `update()`, or `save-update()` method.
-
-Additionally, when a child item is attached to a parent item that specifies the "save-update" cascade value on the relationship, the child is automatically passed to `save_or_update()` (and the operation is further cascaded to the child item).
+The above mapper specifies two relations, `items` and `customer`. The `items` relationship specifies "all, delete-orphan" as its `cascade` value, indicating that all `add`, `merge`, `expunge`, `refresh`, `delete` and `expire` operations performed on a parent `Order` instance should also be performed on the child `Item` instances attached to it. The `delete-orphan` cascade value additionally indicates that if an `Item` instance is no longer associated with an `Order`, it should also be deleted. The "all, delete-orphan" cascade argument allows a so-called *lifecycle* relationship between an `Order` and an `Item` object.
-Note that cascading doesn't do anything that isn't possible by manually calling Session methods on individual instances within a hierarchy, it merely automates common operations on a group of associated instances.
+The `customer` relationship specifies only the "save-update" cascade value, indicating most operations will not be cascaded from a parent `Order` instance to a child `User` instance except for the `add()` operation. "save-update" cascade indicates that an `add()` on the parent will cascade to all child items, and also that items attached to a parent which is already present in the session will themselves be added to that session.
The default value for `cascade` on `relation()`s is `save-update, merge`.
## Managing Transactions
-The Session can manage transactions automatically, including across multiple engines. When the Session is in a transaction, as it receives requests to execute SQL statements, it adds each individual Connection/Engine encountered to its transactional state. At commit time, all unflushed data is flushed, and each individual transaction is committed. If the underlying databases support two-phase semantics, this may be used by the Session as well if two-phase transactions are enabled.
+The `Session` manages transactions across all engines associated with it. As the `Session` receives requests to execute SQL statements using a particular `Engine` or `Connection`, it adds each individual `Engine` encountered to its transactional state and maintains an open connection for each one (note that a simple application normally has just one `Engine`). At commit time, all unflushed data is flushed, and each individual transaction is committed. If the underlying databases support two-phase semantics, this may be used by the Session as well if two-phase transactions are enabled.
-The easiest way to use a Session with transactions is just to declare it as transactional. The session will remain in a transaction at all times:
+Normal operation ends the transactional state using the `rollback()` or `commit()` methods. After either is called, the `Session` starts a new transaction.
{python}
- # transactional session
- Session = sessionmaker(transactional=True)
+ Session = sessionmaker()
sess = Session()
try:
item1 = sess.query(Item).get(1)
# rollback - will immediately go into a new transaction afterwards.
sess.rollback()
-Things to note above:
-
- * When using a transactional session, either a `rollback()` or a `close()` call **is required** when an error is raised by `flush()` or `commit()`. The `flush()` error condition will issue a ROLLBACK to the database automatically, but the state of the `Session` itself remains in an "undefined" state until the user decides whether to rollback or close.
- * The `commit()` call unconditionally issues a `flush()`. Particularly when using `transactional=True` in conjunction with `autoflush=True`, explicit `flush()` calls are usually not needed.
-
-Alternatively, a transaction can be begun explicitly using `begin()`:
+A session which is configured with `autocommit=True` may be placed into a transaction using `begin()`. Once begun in this way, the session releases all connection resources after a `commit()` or `rollback()` and remains transaction-less (with the exception of flushes) until the next `begin()` call:
{python}
- # non transactional session
- Session = sessionmaker(transactional=False)
+ Session = sessionmaker(autocommit=True)
sess = Session()
sess.begin()
try:
sess.rollback()
raise
-Like the `transactional` example, the same rules apply; an explicit `rollback()` or `close()` is required when an error occurs, and the `commit()` call issues a `flush()` as well.
-
-Session also supports Python 2.5's with statement so that the example above can be written as:
+The `begin()` method also returns a transactional token which is compatible with the Python 2.6 `with` statement:
{python}
- Session = sessionmaker(transactional=False)
+ Session = sessionmaker(autocommit=True)
sess = Session()
with sess.begin():
item1 = sess.query(Item).get(1)
item1.foo = 'bar'
item2.bar = 'foo'
-Subtransactions can be created by calling the `begin()` method repeatedly. For each transaction you `begin()` you must always call either `commit()` or `rollback()`. Note that this includes the implicit transaction created by the transactional session. When a subtransaction is created the current transaction of the session is set to that transaction. Commiting the subtransaction will return you to the next outer transaction. Rolling it back will also return you to the next outer transaction, but in addition it will roll back database state to the innermost transaction that supports rolling back to. Usually this means the root transaction, unless you use the nested transaction functionality via the `begin_nested()` method. MySQL and Postgres (and soon Oracle) support using "nested" transactions by creating SAVEPOINTs, :
+SAVEPOINT transactions, if supported by the underlying engine, may be delineated using the `begin_nested()` method:
{python}
- Session = sessionmaker(transactional=False)
+ Session = sessionmaker()
sess = Session()
- sess.begin()
- sess.save(u1)
- sess.save(u2)
- sess.flush()
+ sess.add(u1)
+ sess.add(u2)
sess.begin_nested() # establish a savepoint
- sess.save(u3)
+ sess.add(u3)
sess.rollback() # rolls back u3, keeps u1 and u2
sess.commit() # commits u1 and u2
+`begin_nested()` may be called any number of times, which will issue a new SAVEPOINT with a unique identifier for each call. For each `begin_nested()` call, a corresponding `rollback()` or `commit()` must be issued.
+
+When `begin_nested()` is called, a `flush()` is unconditionally issued (regardless of the `autoflush` setting). This is so that when a `rollback()` occurs, the full state of the session is expired, thus causing all subsequent attribute/instance access to reference the full state of the `Session` right before `begin_nested()` was called.
+
Finally, for MySQL, Postgres, and soon Oracle as well, the session can be instructed to use two-phase commit semantics. This will coordinate the committing of transactions across databases so that the transaction is either committed or rolled back in all databases. You can also `prepare()` the session for interacting with transactions not managed by SQLAlchemy. To use two phase transactions set the flag `twophase=True` on the session:
{python}
engine1 = create_engine('postgres://db1')
engine2 = create_engine('postgres://db2')
- Session = sessionmaker(twophase=True, transactional=True)
+ Session = sessionmaker(twophase=True)
# bind User operations to engine 1, Account operations to engine 2
Session.configure(binds={User:engine1, Account:engine2})
# before committing both transactions
sess.commit()
-Be aware that when a crash occurs in one of the databases while the the transactions are prepared you have to manually commit or rollback the prepared transactions in your database as appropriate.
-
## Embedding SQL Insert/Update Expressions into a Flush {@name=flushsql}
This feature allows the value of a database column to be set to a SQL expression instead of a literal value. It's especially useful for atomic updates, calling stored procedures, etc. All you do is assign an expression to an attribute:
# issues "UPDATE some_table SET value=value+1"
session.commit()
-This works both for INSERT and UPDATE statements. After the flush/commit operation, the `value` attribute on `someobject` gets "deferred", so that when you again access it the newly generated value will be loaded from the database. This is the same mechanism at work when database-side column defaults fire off.
+This technique works both for INSERT and UPDATE statements. After the flush/commit operation, the `value` attribute on `someobject` above is expired, so that when next accessed the newly generated value will be loaded from the database.
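+
+A complete round trip resembles this sketch, assuming `someobject` is mapped to `some_table` with a `value` column:
+
+    {python}
+    someobject.value = some_table.c.value + 1
+    session.commit()          # issues "UPDATE some_table SET value=value+1"
+
+    # 'value' was expired; accessing it loads the new value
+    print someobject.value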
## Using SQL Expressions with Sessions {@name=sql}
-SQL constructs and string statements can be executed via the `Session`. You'd want to do this normally when your `Session` is transactional and you'd like your free-standing SQL statements to participate in the same transaction.
-
-The two ways to do this are to use the connection/execution services of the Session, or to have your Session participate in a regular SQL transaction.
-
-First, a Session thats associated with an Engine or Connection can execute statements immediately (whether or not it's transactional):
+SQL expressions and strings can be executed via the `Session` within its transactional context. This is most easily accomplished using the `execute()` method, which returns a `ResultProxy` in the same manner as an `Engine` or `Connection`:
{python}
- Session = sessionmaker(bind=engine, transactional=True)
+ Session = sessionmaker(bind=engine)
sess = Session()
+
+ # execute a string statement
result = sess.execute("select * from table where id=:id", {'id':7})
- result2 = sess.execute(select([mytable], mytable.c.id==7))
+
+ # execute a SQL expression construct
+ result = sess.execute(select([mytable]).where(mytable.c.id==7))
-To get at the current connection used by the session, which will be part of the current transaction if one is in progress, use `connection()`:
+The current `Connection` held by the `Session` is accessible using the `connection()` method:
{python}
connection = sess.connection()
-
-A second scenario is that of a Session which is not directly bound to a connectable. This session executes statements relative to a particular `Mapper`, since the mappers are bound to tables which are in turn bound to connectables via their `MetaData` (either the session or the mapped tables need to be bound). In this case, the Session can conceivably be associated with multiple databases through different mappers; so it wants you to send along a `mapper` argument, which can be any mapped class or mapper instance:
+The examples above deal with a `Session` that's bound to a single `Engine` or `Connection`. To execute statements using a `Session` which is bound either to multiple engines, or none at all (i.e. relies upon bound metadata), both `execute()` and `connection()` accept a `mapper` keyword argument, which accepts a mapped class or `Mapper` instance and is used to locate the proper context for the desired engine:
+
{python}
- # session is *not* bound to an engine or connection
- Session = sessionmaker(transactional=True)
+ Session = sessionmaker()
sess = Session()
# need to specify mapper or class when executing
result = sess.execute("select * from table where id=:id", {'id':7}, mapper=MyMappedClass)
- result2 = sess.execute(select([mytable], mytable.c.id==7), mapper=MyMappedClass)
- # need to specify mapper or class when you call connection()
+ result = sess.execute(select([mytable], mytable.c.id==7), mapper=MyMappedClass)
+
connection = sess.connection(MyMappedClass)
-The third scenario is when you are using `Connection` and `Transaction` yourself, and want the `Session` to participate. This is easy, as you just bind the `Session` to the connection:
+## Joining a Session into an External Transaction {@name=joining}
+
+If a `Connection` is being used which is already in a transactional state (i.e. has a `Transaction`), a `Session` can be made to participate within that transaction by just binding the `Session` to that `Connection`:
{python}
- # non-transactional session
- Session = sessionmaker(transactional=False)
+ Session = sessionmaker()
# non-ORM connection + transaction
conn = engine.connect()
trans = conn.begin()
- # bind the Session *instance* to the connection
+ # create a Session, bind to the connection
sess = Session(bind=conn)
- # ... etc
+ # ... work with session
- trans.commit()
+ sess.commit() # commit the session
+ sess.close() # close it out, prohibit further actions
-It's safe to use a `Session` which is transactional or autoflushing, as well as to call `begin()`/`commit()` on the session too; the outermost Transaction object, the one we declared explicitly, controls the scope of the transaction.
+ trans.commit() # commit the actual transaction
-When using the `threadlocal` engine context, things are that much easier; the `Session` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it:
+Note that above, we issue a `commit()` both on the `Session` as well as the `Transaction`. This is an example of where we take advantage of `Connection`'s ability to maintain *subtransactions*, or nested begin/commit pairs. The `Session` is used exactly as though it were managing the transaction on its own; its `commit()` method issues its `flush()`, and commits the subtransaction. The subsequent transaction the `Session` starts after commit will not begin until it's next used. Above we issue a `close()` to prevent this from occurring. Finally, the actual transaction is committed using `Transaction.commit()`.
+
+When using the `threadlocal` engine context, the process above is simplified; the `Session` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it:
{python}
engine = create_engine('postgres://mydb', strategy="threadlocal")
{python}
from sqlalchemy.orm import scoped_session, sessionmaker
- Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+ Session = scoped_session(sessionmaker())
However, when you instantiate this `Session` "class", in reality the object is pulled from a threadlocal variable, or if it doesn't exist yet, it's created using the underlying class generated by `sessionmaker()`:
u2 = User()
# save to the contextual session, without instantiating
- Session.save(u1)
- Session.save(u2)
+ Session.add(u1)
+ Session.add(u2)
# view the "new" attribute
assert u1 in Session.new
- # flush changes (if not using autoflush)
- Session.flush()
-
- # commit transaction (if using a transactional session)
+ # commit changes
Session.commit()
-To "dispose" of the `Session`, there's two general approaches. One is to close out the current session, but to leave it assigned to the current context. This allows the same object to be re-used on another operation. This may be called from a current, instantiated `Session`:
-
- {python}
- sess.close()
-
-Or, when using `scoped_session()`, the `close()` method may also be called as a classmethod on the `Session` "class":
-
- {python}
- Session.close()
-
-When the `Session` is closed, it remains attached, but clears all of its contents and releases any ongoing transactional resources, including rolling back any remaining transactional state. The `Session` can then be used again.
-
-The other method is to remove the current session from the current context altogether. This is accomplished using the classmethod `remove()`:
+The contextual session may be disposed of by calling `Session.remove()`:
{python}
+ # remove current contextual session
Session.remove()
-
-After `remove()` is called, the next call to `Session()` will create a *new* `Session` object which then becomes the contextual session.
-That, in a nutshell, is all there really is to it. Now for all the extra things one should know.
+After `remove()` is called, the next operation with the contextual session will start a new `Session` for the current thread.
### Lifespan of a Contextual Session {@name=lifespan}
# some other code calls Session, it's the
# same contextual session as "sess"
sess2 = Session()
- sess2.save(foo)
+ sess2.add(foo)
sess2.commit()
# generate content to be returned
Session.remove() <-
web response <-
-Above, we illustrate a *typical* organization of duties, where the "Web Framework" layer has some integration built-in to manage the span of ORM sessions. Upon the initial handling of an incoming web request, the framework passes control to a controller. The controller then calls `Session()` when it wishes to work with the ORM; this method establishes the contextual Session which will remain until it's removed. Disparate parts of the controller code may all call `Session()` and will get the same session object. Then, when the controller has completed and the response is to be sent to the web server, the framework **closes out** the current contextual session, above using the `remove()` method which removes the session from the context altogether.
-
-As an alternative, the "finalization" step can also call `Session.close()`, which will leave the same session object in place. Which one is better ? For a web framework which runs from a fixed pool of threads, it doesn't matter much. For a framework which runs a **variable** number of threads, or which **creates and disposes** of a thread for each request, `remove()` is better, since it leaves no resources associated with the thread which might not exist.
-
-* Why close out the session at all ? Why not just leave it going so the next request doesn't have to do as many queries ?
-
- There are some cases where you may actually want to do this. However, this is a special case where you are dealing with data which **does not change** very often, or you don't care about the "freshness" of the data. In reality, a single thread of a web server may, on a slow day, sit around for many minutes or even hours without being accessed. When it's next accessed, if data from the previous request still exists in the session, that data may be very stale indeed. So it's generally better to have an empty session at the start of a web request.
-
-### Associating Classes and Mappers with a Contextual Session {@name=associating}
-
-Another luxury we gain, when we've established a `Session()` that can be globally accessed, is the ability for mapped classes and objects to provide us with session-oriented functionality automatically. When using the `scoped_session()` function, we access this feature using the `mapper` attribute on the object in place of the normal `sqlalchemy.orm.mapper` function:
-
- {python}
- # "contextual" mapper function
- mapper = Session.mapper
-
- # use normally
- mapper(User, users_table, properties={
- relation(Address)
- })
- mapper(Address, addresses_table)
-
-When we use the contextual `mapper()` function, our `User` and `Address` now gain a new attribute `query`, which will create a `Query` object for us against the contextual session:
-
- {python}
- wendy = User.query.filter_by(name='wendy').one()
-
-#### Auto-Save Behavior with Contextual Session's Mapper {@name=autosave}
-
-By default, when using Session.mapper, **new instances are saved into the contextual session automatically upon construction;** there is no longer a need to call `save()`:
-
- {python}
- >>> newuser = User(name='ed')
- >>> assert newuser in Session.new
- True
-
-The auto-save functionality can cause problems, namely that any `flush()` which occurs before a newly constructed object is fully populated will result in that object being INSERTed without all of its attributes completed. As a `flush()` is more frequent when using sessions with `autoflush=True`, **the auto-save behavior can be disabled**, using the `save_on_init=False` flag:
-
- {python}
- # "contextual" mapper function
- mapper = Session.mapper
+The above example illustrates an explicit call to `Session.remove()`. This has the effect that each web request starts fresh with a brand new session. When integrating with a web framework, there are actually many options on how to proceed for this step, particularly as of version 0.5:
- # use normally, specify no save on init:
- mapper(User, users_table, properties={
- relation(Address)
- }, save_on_init=False)
- mapper(Address, addresses_table, save_on_init=False)
+ * Session.remove() - this is the most cut-and-dried approach; the `Session` is thrown away, all of its transactional/connection resources are closed out, everything within it is explicitly gone. A new `Session` will be used on the next request.
+ * Session.close() - Similar to calling `remove()`, in that all objects are explicitly expunged and all transactional/connection resources closed, except the actual `Session` object hangs around. It doesn't make too much difference here unless the start of the web request would like to pass specific options to the initial construction of `Session()`, such as a specific `Engine` to bind to.
+ * Session.commit() - In this case, the behavior is that any remaining changes pending are flushed, and the transaction is committed. The full state of the session is expired, so that when the next web request is started, all data will be reloaded. In reality, the contents of the `Session` are weakly referenced anyway so it's likely that it will be empty on the next request in any case.
+ * Session.rollback() - Similar to calling `commit()`, except we assume that the user would have called `commit()` explicitly if that was desired; the `rollback()` ensures that no transactional state remains and expires all data, in the case that the request was aborted and did not roll back itself.
+ * do nothing - this is a valid option as well. The controller code is responsible for doing one of the above steps at the end of the request.
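+
+As an illustration, a hypothetical framework integration might finalize each request with a function like this sketch:
+
+    {python}
+    def end_of_request(response):
+        # dispose of the contextual session; the next request
+        # which calls Session() gets a brand new one
+        Session.remove()
+        return response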
- # objects now again require explicit "save"
- >>> newuser = User(name='ed')
- >>> assert newuser in Session.new
- False
-
- >>> Session.save(newuser)
- >>> assert newuser in Session.new
- True
-
-The functionality of `Session.mapper` is an updated version of what used to be accomplished by the `assignmapper()` SQLAlchemy extension.
-
[Generated docstrings for scoped_session()](rel:docstrings_sqlalchemy.orm_modfunc_scoped_session)
## Partitioning Strategies
engine1 = create_engine('postgres://db1')
engine2 = create_engine('postgres://db2')
- Session = sessionmaker(twophase=True, transactional=True)
+ Session = sessionmaker(twophase=True)
# bind User operations to engine 1, Account operations to engine 2
Session.configure(binds={User:engine1, Account:engine2})
## Version Check
-A quick check to verify that we are on at least **version 0.4** of SQLAlchemy:
+A quick check to verify that we are on at least **version 0.5** of SQLAlchemy:
{python}
>>> import sqlalchemy
>>> sqlalchemy.__version__ # doctest:+SKIP
- 0.4.0
+ 0.5.0
## Connecting
>>> metadata = MetaData()
>>> users = Table('users', metadata,
... Column('id', Integer, primary_key=True),
- ... Column('name', String(40)),
- ... Column('fullname', String(100)),
+ ... Column('name', String),
+ ... Column('fullname', String),
... )
>>> addresses = Table('addresses', metadata,
... Column('id', Integer, primary_key=True),
... Column('user_id', None, ForeignKey('users.id')),
- ... Column('email_address', String(50), nullable=False)
+ ... Column('email_address', String, nullable=False)
... )
All about how to define `Table` objects, as well as how to create them from an existing database automatically, is described in [metadata](rel:metadata).
{}
CREATE TABLE users (
id INTEGER NOT NULL,
- name VARCHAR(40),
- fullname VARCHAR(100),
+ name VARCHAR,
+ fullname VARCHAR,
PRIMARY KEY (id)
)
{}
CREATE TABLE addresses (
id INTEGER NOT NULL,
user_id INTEGER,
- email_address VARCHAR(50) NOT NULL,
+ email_address VARCHAR NOT NULL,
PRIMARY KEY (id),
FOREIGN KEY(user_id) REFERENCES users (id)
)
{}
COMMIT
+Users familiar with the syntax of CREATE TABLE may notice that the VARCHAR columns were generated without a length; on SQLite, this is a valid datatype, but on most databases it's not allowed. So if running this tutorial on a database such as Postgres or MySQL, and you wish to use SQLAlchemy to generate the tables, a "length" may be provided to the `String` type as below:
+
+ {python}
+ Column('name', String(50))
+
+The length field on `String`, as well as the similar fields available on `Integer`, `Numeric`, etc., is not referenced by SQLAlchemy other than when creating tables.
+
## Insert Expressions
The first SQL expression we'll create is the `Insert` construct, which represents an INSERT statement. This is typically created relative to its target table:
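A minimal sketch, using the `users` table defined earlier (the echoed string is what `str()` typically produces for this construct):

    {python}
    >>> ins = users.insert()
    >>> str(ins) # doctest:+SKIP
    'INSERT INTO users (id, name, fullname) VALUES (:id, :name, :fullname)'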
{python}
>>> print users.c.id==7
- users.id = :users_id_1
+ users.id = :id_1
The `7` literal is embedded in the resulting `ClauseElement`; we can use the same trick we did with the `Insert` object to see it:
{python}
>>> (users.c.id==7).compile().params
- {'users_id_1': 7}
+ {'id_1': 7}
Most Python operators, as it turns out, produce a SQL expression here, like equals, not equals, etc.:
{python}
>>> print users.c.id != 7
- users.id != :users_id_1
+ users.id != :id_1
>>> # None converts to IS NULL
>>> print users.c.name == None
>>> # reverse works too
>>> print 'fred' > users.c.name
- users.name < :users_name_1
+ users.name < :name_1
If we add two integer columns together, we get an addition expression:
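For instance, a quick sketch using the `id` columns of our two tables:

    {python}
    >>> print users.c.id + addresses.c.id
    users.id + addresses.id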
{python}
>>> print users.c.name.op('tiddlywinks')('foo')
- users.name tiddlywinks :users_name_1
+ users.name tiddlywinks :name_1
## Conjunctions {@name=conjunctions}
>>> print and_(users.c.name.like('j%'), users.c.id==addresses.c.user_id, #doctest: +NORMALIZE_WHITESPACE
... or_(addresses.c.email_address=='wendy@aol.com', addresses.c.email_address=='jack@yahoo.com'),
... not_(users.c.id>5))
- users.name LIKE :users_name_1 AND users.id = addresses.user_id AND
- (addresses.email_address = :addresses_email_address_1 OR addresses.email_address = :addresses_email_address_2)
- AND users.id <= :users_id_1
+ users.name LIKE :name_1 AND users.id = addresses.user_id AND
+ (addresses.email_address = :email_address_1 OR addresses.email_address = :email_address_2)
+ AND users.id <= :id_1
You can also use the re-jiggered bitwise AND, OR and NOT operators, although because of Python operator precedence you have to watch your parentheses:
>>> print users.c.name.like('j%') & (users.c.id==addresses.c.user_id) & \
... ((addresses.c.email_address=='wendy@aol.com') | (addresses.c.email_address=='jack@yahoo.com')) \
... & ~(users.c.id>5) # doctest: +NORMALIZE_WHITESPACE
- users.name LIKE :users_name_1 AND users.id = addresses.user_id AND
- (addresses.email_address = :addresses_email_address_1 OR addresses.email_address = :addresses_email_address_2)
- AND users.id <= :users_id_1
+ users.name LIKE :name_1 AND users.id = addresses.user_id AND
+ (addresses.email_address = :email_address_1 OR addresses.email_address = :email_address_2)
+ AND users.id <= :id_1
So with all of this vocabulary, let's select all users who have an email address at AOL or MSN, whose name starts with a letter between "m" and "z", and we'll also generate a column containing their full name combined with their email address. We will add two new constructs to this statement, `between()` and `label()`. `between()` produces a BETWEEN clause, and `label()` is used in a column expression to produce labels using the `AS` keyword; it's recommended when selecting from expressions that otherwise would not have a name:
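A sketch of how such a statement might look, using only constructs introduced so far (exact label and bind parameter names may differ):

    {python}
    >>> s = select([(users.c.fullname + ", " + addresses.c.email_address).label('title')],
    ...     and_(
    ...         users.c.id==addresses.c.user_id,
    ...         users.c.name.between('m', 'z'),
    ...         or_(
    ...             addresses.c.email_address.like('%@aol.com'),
    ...             addresses.c.email_address.like('%@msn.com')
    ...         )
    ...     )
    ... )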
{python}
>>> print users.join(addresses, addresses.c.email_address.like(users.c.name + '%'))
- users JOIN addresses ON addresses.email_address LIKE users.name || :users_name_1
+ users JOIN addresses ON addresses.email_address LIKE users.name || :name_1
When we create a `select()` construct, SQLAlchemy looks around at the tables we've mentioned and then places them in the FROM clause of the statement. When we use JOINs however, we know what FROM clause we want, so here we make use of the `from_obj` keyword argument.
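A sketch, reusing the join from the example above and selecting just the `fullname` column:

    {python}
    >>> s = select([users.c.fullname], from_obj=[
    ...     users.join(addresses, addresses.c.email_address.like(users.c.name + '%'))
    ... ])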
>>> print query
{opensql}SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, addresses_1.id AS addresses_1_id, addresses_1.user_id AS addresses_1_user_id, addresses_1.email_address AS addresses_1_email_address
FROM users LEFT OUTER JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id
- WHERE users.name = :users_name_1 AND (EXISTS (SELECT addresses_1.id
+ WHERE users.name = :name_1 AND (EXISTS (SELECT addresses_1.id
FROM addresses AS addresses_1
- WHERE addresses_1.user_id = users.id AND addresses_1.email_address LIKE :addresses_email_address_1)) ORDER BY users.fullname DESC
+ WHERE addresses_1.user_id = users.id AND addresses_1.email_address LIKE :email_address_1)) ORDER BY users.fullname DESC
One more thing though, with automatic labeling applied as well as anonymous aliasing, how do we retrieve the columns from the rows for this thing? The label for the `email_address` column is now the generated name `addresses_1_email_address`, and in another statement it might be something different! This is where accessing result columns by `Column` object becomes very useful:
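A sketch of the pattern, with a hypothetical `adalias` standing in for whatever alias of `addresses` appears in the statement:

    {python}
    >>> adalias = addresses.alias()
    >>> s = select([users, adalias], users.c.id==adalias.c.user_id)
    >>> for row in conn.execute(s): # doctest:+SKIP
    ...     print row[users.c.name], row[adalias.c.email_address]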
To embed a SELECT in a column expression, use `as_scalar()`:
{python}
- {sql}>>> print conn.execute(select([
+ {sql}>>> print conn.execute(select([ # doctest: +NORMALIZE_WHITESPACE
... users.c.name,
... select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).as_scalar()
... ])).fetchall()
- SELECT users.name, (SELECT count(addresses.id)
+ SELECT users.name, (SELECT count(addresses.id) AS count_1
FROM addresses
WHERE users.id = addresses.user_id) AS anon_1
FROM users
Alternatively, applying a `label()` to a select evaluates it as a scalar as well:
{python}
- {sql}>>> print conn.execute(select([
+ {sql}>>> print conn.execute(select([ # doctest: +NORMALIZE_WHITESPACE
... users.c.name,
... select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).label('address_count')
... ])).fetchall()
- SELECT users.name, (SELECT count(addresses.id)
+ SELECT users.name, (SELECT count(addresses.id) AS count_1
FROM addresses
WHERE users.id = addresses.user_id) AS address_count
FROM users
>>> s = select([addresses.c.user_id, func.count(addresses.c.id)]).\
... group_by(addresses.c.user_id).having(func.count(addresses.c.id)>1)
{opensql}>>> print conn.execute(s).fetchall()
- SELECT addresses.user_id, count(addresses.id)
+ SELECT addresses.user_id, count(addresses.id) AS count_1
FROM addresses GROUP BY addresses.user_id
HAVING count(addresses.id) > ?
[1]
from sqlalchemy.sql import compiler, expression
from sqlalchemy.engine import default, strategies, threadlocal, url
import sqlalchemy.orm.shard
-import sqlalchemy.ext.sessioncontext as sessioncontext
-import sqlalchemy.ext.selectresults as selectresults
import sqlalchemy.ext.orderinglist as orderinglist
import sqlalchemy.ext.associationproxy as associationproxy
-import sqlalchemy.ext.assignmapper as assignmapper
import sqlalchemy.ext.sqlsoup as sqlsoup
import sqlalchemy.ext.declarative as declarative
else:
to_gen = files + post_files
-title='SQLAlchemy 0.4 Documentation'
+title='SQLAlchemy 0.5 Documentation'
version = options.version
import re
import doctest
import sqlalchemy.util as util
-import sqlalchemy.logging as salog
+import sqlalchemy.log as salog
import logging
salog.default_enabled=True
--- /dev/null
+"""this example illustrates how to replace SQLAlchemy's class descriptors with a user-defined system.
+
+This sort of thing is appropriate for integration with frameworks that redefine class behaviors
+in their own way, such that SQLA's default instrumentation is not compatible.
+
+The example illustrates redefinition of instrumentation at the class level as well as the collection
+level, and redefines the storage of the class to store state within "instance._goofy_dict" instead
+of "instance.__dict__". Note that the default collection implementations can be used
+with a custom attribute system as well.
+
+"""
+from sqlalchemy import *
+from sqlalchemy.orm import *
+
+from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
+from sqlalchemy.orm.collections import collection_adapter
+
+
+class MyClassState(InstrumentationManager):
+ def __init__(self, cls):
+ self.states = {}
+
+ def instrument_attribute(self, class_, key, attr):
+ pass
+
+ def install_descriptor(self, class_, key, attr):
+ pass
+
+ def uninstall_descriptor(self, class_, key, attr):
+ pass
+
+ def instrument_collection_class(self, class_, key, collection_class):
+ return MyCollection
+
+ def get_instance_dict(self, class_, instance):
+ return instance._goofy_dict
+
+ def initialize_instance_dict(self, class_, instance):
+ instance.__dict__['_goofy_dict'] = {}
+
+ def initialize_collection(self, key, state, factory):
+ data = factory()
+ return MyCollectionAdapter(key, state, data), data
+
+ def install_state(self, class_, instance, state):
+ self.states[id(instance)] = state
+
+ def state_getter(self, class_):
+ def find(instance):
+ return self.states[id(instance)]
+ return find
+
+class MyClass(object):
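+ # directs SQLAlchemy to manage this class with MyClassState instead of its native instrumentation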
+ __sa_instrumentation_manager__ = MyClassState
+
+ def __init__(self, **kwargs):
+ for k in kwargs:
+ setattr(self, k, kwargs[k])
+
+ def __getattr__(self, key):
+ if is_instrumented(self, key):
+ return get_attribute(self, key)
+ else:
+ try:
+ return self._goofy_dict[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __setattr__(self, key, value):
+ if is_instrumented(self, key):
+ set_attribute(self, key, value)
+ else:
+ self._goofy_dict[key] = value
+
+ def __delattr__(self, key):
+ if is_instrumented(self, key):
+ del_attribute(self, key)
+ else:
+ del self._goofy_dict[key]
+
+class MyCollectionAdapter(object):
+ """An wholly alternative instrumentation implementation."""
+ def __init__(self, key, state, collection):
+ self.key = key
+ self.state = state
+ self.collection = collection
+ setattr(collection, '_sa_adapter', self)
+
+ def unlink(self, data):
+ setattr(data, '_sa_adapter', None)
+
+ def adapt_like_to_iterable(self, obj):
+ return iter(obj)
+
+ def append_with_event(self, item, initiator=None):
+ self.collection.add(item, emit=initiator)
+
+ def append_without_event(self, item):
+ self.collection.add(item, emit=False)
+
+ def remove_with_event(self, item, initiator=None):
+ self.collection.remove(item, emit=initiator)
+
+ def remove_without_event(self, item):
+ self.collection.remove(item, emit=False)
+
+ def clear_with_event(self, initiator=None):
+ for item in list(self):
+ self.remove_with_event(item, initiator)
+
+ def clear_without_event(self):
+ for item in list(self):
+ self.remove_without_event(item)
+
+ def __iter__(self):
+ return iter(self.collection)
+
+ def fire_append_event(self, item, initiator=None):
+ if initiator is not False and item is not None:
+ self.state.get_impl(self.key).fire_append_event(self.state, item,
+ initiator)
+
+ def fire_remove_event(self, item, initiator=None):
+ if initiator is not False and item is not None:
+ self.state.get_impl(self.key).fire_remove_event(self.state, item,
+ initiator)
+
+ def fire_pre_remove_event(self, initiator=None):
+ self.state.get_impl(self.key).fire_pre_remove_event(self.state,
+ initiator)
+
+class MyCollection(object):
+ def __init__(self):
+ self.members = list()
+
+ def add(self, object, emit=None):
+ self.members.append(object)
+ collection_adapter(self).fire_append_event(object, emit)
+
+ def remove(self, object, emit=None):
+ # fire_pre_remove_event takes only an optional initiator, not the item itself
+ collection_adapter(self).fire_pre_remove_event()
+ self.members.remove(object)
+ collection_adapter(self).fire_remove_event(object, emit)
+
+ def __getitem__(self, index):
+ return self.members[index]
+
+ def __iter__(self):
+ return iter(self.members)
+
+ def __len__(self):
+ return len(self.members)
+
+if __name__ == '__main__':
+ meta = MetaData(create_engine('sqlite://'))
+
+ table1 = Table('table1', meta, Column('id', Integer, primary_key=True), Column('name', Text))
+ table2 = Table('table2', meta, Column('id', Integer, primary_key=True), Column('name', Text), Column('t1id', Integer, ForeignKey('table1.id')))
+ meta.create_all()
+
+ class A(MyClass):
+ pass
+
+ class B(MyClass):
+ pass
+
+ mapper(A, table1, properties={
+ 'bs':relation(B)
+ })
+
+ mapper(B, table2)
+
+ a1 = A(name='a1', bs=[B(name='b1'), B(name='b2')])
+
+ assert a1.name == 'a1'
+ assert a1.bs[0].name == 'b1'
+ assert isinstance(a1.bs, MyCollection)
+
+ sess = create_session()
+ sess.save(a1)
+
+ sess.flush()
+ sess.clear()
+
+ a1 = sess.query(A).get(a1.id)
+
+ assert a1.name == 'a1'
+ assert a1.bs[0].name == 'b1'
+ assert isinstance(a1.bs, MyCollection)
+
+ a1.bs.remove(a1.bs[0])
+
+ sess.flush()
+ sess.clear()
+
+ a1 = sess.query(A).get(a1.id)
+ assert len(a1.bs) == 1
"""\r
\r
class MyProxyDict(object):\r
- def __init__(self, parent, collection_name, keyname):\r
+ def __init__(self, parent, collection_name, childclass, keyname):\r
self.parent = parent\r
self.collection_name = collection_name\r
+ self.childclass = childclass\r
self.keyname = keyname\r
\r
def collection(self):\r
collection = property(collection)\r
\r
def keys(self):\r
- # this can be improved to not query all columns\r
- return [getattr(x, self.keyname) for x in self.collection.all()]\r
+ descriptor = getattr(self.childclass, self.keyname)\r
+ return [x[0] for x in self.collection.values(descriptor)]\r
\r
def __getitem__(self, key):\r
x = self.collection.filter_by(**{self.keyname:key}).first()\r
_collection = dynamic_loader("MyChild", cascade="all, delete-orphan")\r
\r
def child_map(self):\r
- return MyProxyDict(self, '_collection', 'key')\r
+ return MyProxyDict(self, '_collection', MyChild, 'key')\r
child_map = property(child_map)\r
\r
class MyChild(Base):\r
\r
Base.metadata.create_all()\r
\r
-sess = create_session(autoflush=True, transactional=True)\r
+sess = sessionmaker()()\r
\r
p1 = MyParent(name='p1')\r
-sess.save(p1)\r
+sess.add(p1)\r
\r
p1.child_map['k1'] = k1 = MyChild(key='k1')\r
p1.child_map['k2'] = k2 = MyChild(key='k2')\r
\r
-\r
assert p1.child_map.keys() == ['k1', 'k2']\r
\r
assert p1.child_map['k1'] is k1\r
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import inspect
+import sys
+
+import sqlalchemy.exc as exceptions
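+# registering the module under its old name keeps "import sqlalchemy.exceptions" working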
+sys.modules['sqlalchemy.exceptions'] = exceptions
+
from sqlalchemy.types import \
BLOB, BOOLEAN, CHAR, CLOB, DATE, DATETIME, DECIMAL, FLOAT, INT, \
NCHAR, NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, \
if not (name.startswith('_') or inspect.ismodule(obj)) ]
__version__ = 'svn'
+
+del inspect, sys
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, schema, types, exceptions, pool
+from sqlalchemy import sql, schema, types, exc, pool
from sqlalchemy.sql import compiler, expression
from sqlalchemy.engine import default, base
except pythoncom.com_error:
pass
else:
- raise exceptions.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
+ raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
import pyodbc as module
return module
c.execute(statement, parameters)
self.context.rowcount = c.rowcount
except Exception, e:
- raise exceptions.DBAPIError.instance(statement, parameters, e)
+ raise exc.DBAPIError.instance(statement, parameters, e)
def has_table(self, connection, tablename, schema=None):
# This approach seems to be more reliable that using DAO
if tbl.Name.lower() == table.name.lower():
break
else:
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
for col in tbl.Fields:
coltype = self.ischema_names[col.Type]
# This is necessary, so we get the latest updates
dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
- names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] <> "~TMP"]
+ names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"]
dtbs.Close()
return names
if select.limit:
s += "TOP %s " % (select.limit)
if select.offset:
- raise exceptions.InvalidRequestError('Access does not support LIMIT with an offset')
+ raise exc.InvalidRequestError('Access does not support LIMIT with an offset')
return s
def limit_clause(self, select):
# Strip schema
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
- return self.preparer.quote(table, table.name)
+ return self.preparer.quote(table.name, table.quote)
else:
return ""
class AccessSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
# install a sequence if we have an implicit IDENTITY column
if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
import datetime
-from sqlalchemy import exceptions, schema, types as sqltypes, sql, util
+from sqlalchemy import exc, schema, types as sqltypes, sql, util
from sqlalchemy.engine import base, default
default.DefaultDialect.__init__(self, **kwargs)
self.type_conv = type_conv
- self.concurrency_level= concurrency_level
+ self.concurrency_level = concurrency_level
def dbapi(cls):
import kinterbasdb
version = fbconn.server_version
m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
if not m:
- raise exceptions.AssertionError("Could not determine version from string '%s'" % version)
+ raise AssertionError("Could not determine version from string '%s'" % version)
return tuple([int(x) for x in m.group(5, 6, 4)])
def _normalize_name(self, name):
# get primary key fields
c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
- pkfields =[self._normalize_name(r['fname']) for r in c.fetchall()]
+ pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()]
# get all of the fields for this table
c = connection.execute(tblqry, [tablename])
table.append_column(col)
if not found_table:
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
# get the foreign keys
c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
fks = {}
while True:
row = c.fetchone()
- if not row: break
+ if not row:
+ break
cname = self._normalize_name(row['cname'])
try:
fk[0].append(fname)
fk[1].append(refspec)
- for name,value in fks.iteritems():
+ for name, value in fks.iteritems():
table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
def do_execute(self, cursor, statement, parameters, **kwargs):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
- colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
reserved_words = RESERVED_WORDS
def __init__(self, dialect):
- super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True)
+ super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
dialect = FBDialect
import sqlalchemy.sql as sql
-import sqlalchemy.exceptions as exceptions
+import sqlalchemy.exc as exc
from sqlalchemy import select, MetaData, Table, Column, String, Integer
from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint
coltype = ischema_names[type]
#print "coltype " + repr(coltype) + " args " + repr(args)
coltype = coltype(*args)
- colargs= []
+ colargs = []
if default is not None:
colargs.append(PassiveDefault(sql.text(default)))
table.append_column(Column(name, coltype, nullable=nullable, *colargs))
if not found_table:
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
# we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns
# in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys
row[colmap[6]]
)
#print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column)
- if type=='PRIMARY KEY':
+ if type == 'PRIMARY KEY':
table.primary_key.add(table.c[constrained_column])
- elif type=='FOREIGN KEY':
+ elif type == 'FOREIGN KEY':
try:
fk = fks[constraint_name]
except KeyError:
- fk = ([],[])
+ fk = ([], [])
fks[constraint_name] = fk
if current_schema == referred_schema:
referred_schema = table.schema
import datetime
-from sqlalchemy import sql, schema, exceptions, pool, util
+from sqlalchemy import sql, schema, exc, pool, util
from sqlalchemy.sql import compiler
from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes
# 5 - rowid after insert
def post_exec(self):
if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None:
- self._last_inserted_ids = [self.cursor.sqlerrd[1],]
+ self._last_inserted_ids = [self.cursor.sqlerrd[1]]
elif hasattr( self.compiled , 'offset' ):
self.cursor.offset( self.compiled.offset )
super(InfoExecutionContext, self).post_exec()
# for informix 7.31
max_identifier_length = 18
- def __init__(self, use_ansi=True,**kwargs):
+ def __init__(self, use_ansi=True, **kwargs):
self.use_ansi = use_ansi
default.DefaultDialect.__init__(self, **kwargs)
else:
opt = {}
- return ([dsn,], opt )
+ return ([dsn], opt)
def create_execution_context(self , *args, **kwargs):
return InfoExecutionContext(self, *args, **kwargs)
- def oid_column_name(self,column):
+ def oid_column_name(self, column):
return "rowid"
def table_names(self, connection, schema):
s = "select tabname from systables"
return [row[0] for row in connection.execute(s)]
- def has_table(self, connection, table_name,schema=None):
+ def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() )
return bool( cursor.fetchone() is not None )
c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() )
rows = c.fetchall()
if not rows :
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
else:
if table.owner is not None:
if table.owner.lower() in [r[0] for r in rows]:
owner = table.owner.lower()
else:
- raise exceptions.AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name))
+ raise AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name))
else:
if len(rows)==1:
owner = rows[0][0]
else:
- raise exceptions.AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)
+ raise AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)
c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3
where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=?
rows = c.fetchall()
if not rows:
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
for name , colattr , collength , default , colno in rows:
name = name.lower()
try:
fk = fks[cons_name]
except KeyError:
- fk = ([], [])
- fks[cons_name] = fk
+ fk = ([], [])
+ fks[cons_name] = fk
refspec = ".".join([remote_table, remote_column])
schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection)
if local_column not in fk[0]:
colspec += " SERIAL"
self.has_serial = True
else:
- colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
import datetime, itertools, re
-from sqlalchemy import exceptions, schema, sql, util
+from sqlalchemy import exc, schema, sql, util
from sqlalchemy.sql import operators as sql_operators, expression as sql_expr
from sqlalchemy.sql import compiler, visitors
from sqlalchemy.engine import base as engine_base, default
ms = getattr(value, 'microsecond', 0)
return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms))
else:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"datetimeformat '%s' is not supported." % (
dialect.datetimeformat,))
return process
value[11:13], value[14:16], value[17:19],
value[20:])])
else:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"datetimeformat '%s' is not supported." % (
dialect.datetimeformat,))
return process
elif dialect.datetimeformat == 'iso':
return value.strftime("%Y-%m-%d")
else:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"datetimeformat '%s' is not supported." % (
dialect.datetimeformat,))
return process
return datetime.date(
*[int(v) for v in (value[0:4], value[5:7], value[8:10])])
else:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"datetimeformat '%s' is not supported." % (
dialect.datetimeformat,))
return process
elif dialect.datetimeformat == 'iso':
return value.strftime("%H-%M-%S")
else:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"datetimeformat '%s' is not supported." % (
dialect.datetimeformat,))
return process
return datetime.time(
*[int(v) for v in (value[0:4], value[5:7], value[8:10])])
else:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"datetimeformat '%s' is not supported." % (
dialect.datetimeformat,))
return process
rows = connection.execute(st, params).fetchall()
if not rows:
- raise exceptions.NoSuchTableError(table.fullname)
+ raise exc.NoSuchTableError(table.fullname)
include_columns = util.Set(include_columns or [])
# LIMIT. Right? Other dialects seem to get away with
# dropping order.
if select._limit:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"MaxDB does not support ORDER BY in subqueries")
else:
return ""
sql = select._distinct and 'DISTINCT ' or ''
if self.is_subquery(select) and select._limit:
if select._offset:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
'MaxDB does not support LIMIT with an offset.')
sql += 'TOP %s ' % select._limit
return sql
# sub queries need TOP
return ''
elif select._offset:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
'MaxDB does not support LIMIT with an offset.')
else:
return ' \n LIMIT %s' % (select._limit,)
class MaxDBSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kw):
colspec = [self.preparer.format_column(column),
- column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()]
+ column.type.dialect_impl(self.dialect).get_col_spec()]
if not column.nullable:
colspec.append('NOT NULL')
import datetime, operator, re, sys
-from sqlalchemy import sql, schema, exceptions, util
+from sqlalchemy import sql, schema, exc, util
from sqlalchemy.sql import compiler, expression, operators as sqlops, functions as sql_functions
from sqlalchemy.engine import default, base
from sqlalchemy import types as sqltypes
dialect_cls = dialect_mapping[module_name]
return dialect_cls.import_dbapi()
except KeyError:
- raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
+ raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
else:
for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]:
try:
self.context.rowcount = c.rowcount
c.DBPROP_COMMITPRESERVE = "Y"
except Exception, e:
- raise exceptions.DBAPIError.instance(statement, parameters, e)
+ raise exc.DBAPIError.instance(statement, parameters, e)
def table_names(self, connection, schema):
from sqlalchemy.databases import information_schema as ischema
elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1:
args[0] = None
coltype = coltype(*args)
- colargs= []
+ colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))
if not found_table:
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
# We also run an sp_columns to check for identity columns:
cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema))
row = cursor.fetchone()
cursor.close()
if not row is None:
- ic.sequence.start=int(row[0])
- ic.sequence.increment=int(row[1])
+ ic.sequence.start = int(row[0])
+ ic.sequence.increment = int(row[1])
except:
# ignoring it, works just like before
pass
if rfknm != fknm:
if fknm:
- table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table,s,t,c) for s,t,c in rcols], fknm))
+ table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm))
fknm, scols, rcols = (rfknm, [], [])
- if (not scol in scols): scols.append(scol)
- if (not (rschema, rtbl, rcol) in rcols): rcols.append((rschema, rtbl, rcol))
+ if scol not in scols:
+ scols.append(scol)
+ if (rschema, rtbl, rcol) not in rcols:
+ rcols.append((rschema, rtbl, rcol))
if fknm and scols:
- table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table,s,t,c) for s,t,c in rcols], fknm))
+ table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm))
class MSSQLDialect_pymssql(MSSQLDialect):
if select._limit:
s += "TOP %s " % (select._limit,)
if select._offset:
- raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
+ raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
return compiler.DefaultCompiler.get_select_precolumns(self, select)
class MSSQLSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
# override 'connect' call
def connect(*args, **kwargs):
- import mx.ODBC.Windows
- conn = mx.ODBC.Windows.Connect(*args, **kwargs)
- conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT
- return Connection(conn)
+ import mx.ODBC.Windows
+ conn = mx.ODBC.Windows.Connect(*args, **kwargs)
+ conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT
+ return Connection(conn)
Connect = connect
import datetime, inspect, re, sys
from array import array as _array
-from sqlalchemy import exceptions, logging, schema, sql, util
+from sqlalchemy import exc, log, schema, sql, util
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy.sql import functions as sql_functions
from sqlalchemy.sql import compiler
if ((precision is None and length is not None) or
(precision is not None and length is None)):
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"You must specify both precision and length or omit "
"both altogether.")
super_convert = super(MSEnum, self).bind_processor(dialect)
def process(value):
if self.strict and value is not None and value not in self.enums:
- raise exceptions.InvalidRequestError('"%s" not a valid value for '
+ raise exc.InvalidRequestError('"%s" not a valid value for '
'this enum' % value)
if super_convert:
return super_convert(value)
have = rs.rowcount > 0
rs.close()
return have
- except exceptions.SQLError, e:
+ except exc.SQLError, e:
if e.orig.args[0] == 1146:
return False
raise
try:
try:
rp = connection.execute(st)
- except exceptions.SQLError, e:
+ except exc.SQLError, e:
if e.orig.args[0] == 1146:
- raise exceptions.NoSuchTableError(full_name)
+ raise exc.NoSuchTableError(full_name)
else:
raise
row = _compat_fetchone(rp, charset=charset)
if not row:
- raise exceptions.NoSuchTableError(full_name)
+ raise exc.NoSuchTableError(full_name)
return row[1].strip()
finally:
if rp:
try:
try:
rp = connection.execute(st)
- except exceptions.SQLError, e:
+ except exc.SQLError, e:
if e.orig.args[0] == 1146:
- raise exceptions.NoSuchTableError(full_name)
+ raise exc.NoSuchTableError(full_name)
else:
raise
rows = _compat_fetchall(rp, charset=charset)
def for_update_clause(self, select):
if select.for_update == 'read':
- return ' LOCK IN SHARE MODE'
+ return ' LOCK IN SHARE MODE'
else:
return super(MySQLCompiler, self).for_update_clause(select)
"""Builds column DDL."""
colspec = [self.preparer.format_column(column),
- column.type.dialect_impl(self.dialect,
- _for_ddl=column).get_col_spec()]
+ column.type.dialect_impl(self.dialect).get_col_spec()]
default = self.get_column_default_string(column)
if default is not None:
ref_names = spec['foreign']
if not util.Set(ref_names).issubset(
util.Set([c.name for c in ref_table.c])):
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"Foreign key columns (%s) are not present on "
"foreign table %s" %
(', '.join(ref_names), ref_table.fullname()))
return self._re_keyexprs.findall(identifiers)
-MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector)
+MySQLSchemaReflector.logger = log.class_logger(MySQLSchemaReflector)
class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
import datetime, random, re
-from sqlalchemy import util, sql, schema, exceptions, logging
+from sqlalchemy import util, sql, schema, log
from sqlalchemy.engine import default, base
from sqlalchemy.sql import compiler, visitors
from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
def result_processor(self, dialect):
def process(value):
- if value is None or isinstance(value,datetime.datetime):
+ if value is None or isinstance(value, datetime.datetime):
return value
else:
# convert cx_oracle datetime object returned pre-python 2.4
- return datetime.datetime(value.year,value.month,
+ return datetime.datetime(value.year, value.month,
value.day,value.hour, value.minute, value.second)
return process
def result_processor(self, dialect):
def process(value):
- if value is None or isinstance(value,datetime.datetime):
+ if value is None or isinstance(value, datetime.datetime):
return value
else:
# convert cx_oracle datetime object returned pre-python 2.4
- return datetime.datetime(value.year,value.month,
+ return datetime.datetime(value.year, value.month,
value.day,value.hour, value.minute, value.second)
return process
def get_result_proxy(self):
if hasattr(self, 'out_parameters'):
if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
- for bind, name in self.compiled.bind_names.iteritems():
- if name in self.out_parameters:
- type = bind.type
- self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue())
+ for bind, name in self.compiled.bind_names.iteritems():
+ if name in self.out_parameters:
+ type = bind.type
+ self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue())
else:
- for k in self.out_parameters:
- self.out_parameters[k] = self.out_parameters[k].getvalue()
+ for k in self.out_parameters:
+ self.out_parameters[k] = self.out_parameters[k].getvalue()
if self.cursor.description is not None:
for column in self.cursor.description:
this id will be passed to do_begin_twophase(), do_rollback_twophase(),
do_commit_twophase(). Its format is unspecified."""
- id = random.randint(0,2**128)
+ id = random.randint(0, 2 ** 128)
return (0x1234, "%032x" % 9, "%032x" % id)
def do_release_savepoint(self, connection, name):
cursor = connection.execute(s)
else:
s = "select table_name from all_tables where tablespace_name NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner"
- cursor = connection.execute(s,{'owner':self._denormalize_name(schema)})
+ cursor = connection.execute(s, {'owner': self._denormalize_name(schema)})
return [self._normalize_name(row[0]) for row in cursor]
def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
if desired_owner is None, attempts to locate a distinct owner.
- returns the actual name, owner, dblink name, and synonym name if found.
+ returns the actual name, owner, dblink name, and synonym name if found.
"""
- sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME
- from ALL_SYNONYMS WHERE """
+ sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME
+ from ALL_SYNONYMS WHERE """
clauses = []
params = {}
clauses.append("TABLE_NAME=:tname")
params['tname'] = desired_table
- sql += " AND ".join(clauses)
+ sql += " AND ".join(clauses)
- result = connection.execute(sql, **params)
+ result = connection.execute(sql, **params)
if desired_owner:
row = result.fetchone()
if row:
else:
rows = result.fetchall()
if len(rows) > 1:
- raise exceptions.AssertionError("There are multiple tables visible to the schema, you must specify owner")
+ raise AssertionError("There are multiple tables visible to the schema, you must specify owner")
elif len(rows) == 1:
row = rows[0]
return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
- if resolve_synonyms:
+ if resolve_synonyms:
actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name))
else:
actual_name, owner, dblink, synonym = None, None, None, None
# NUMBER(9,2) if the precision is 9 and the scale is 2
# NUMBER(3) if the precision is 3 and scale is 0
#length is ignored except for CHAR and VARCHAR2
- if coltype=='NUMBER' :
+ if coltype == 'NUMBER' :
if precision is None and scale is None:
coltype = OracleNumeric
elif precision is None and scale == 0 :
table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
if not table.columns:
- raise exceptions.AssertionError("Couldn't find any column information for table %s" % actual_name)
+ raise AssertionError("Couldn't find any column information for table %s" % actual_name)
c = connection.execute("""SELECT
ac.constraint_name,
try:
fk = fks[cons_name]
except KeyError:
- fk = ([], [])
- fks[cons_name] = fk
+ fk = ([], [])
+ fks[cons_name] = fk
if remote_table is None:
# ticket 363
util.warn(
remote_owner = self._normalize_name(ref_remote_owner)
if not table.schema and self._denormalize_name(remote_owner) == owner:
- refspec = ".".join([remote_table, remote_column])
+ refspec = ".".join([remote_table, remote_column])
t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
else:
refspec = ".".join([x for x in [remote_owner, remote_table, remote_column] if x])
table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
-OracleDialect.logger = logging.class_logger(OracleDialect)
+OracleDialect.logger = log.class_logger(OracleDialect)
class _OuterJoinColumn(sql.ClauseElement):
__visit_name__ = 'outer_join_column'
self.column = column
def _get_from_objects(self, **kwargs):
return []
-
+
class OracleCompiler(compiler.DefaultCompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
else:
return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
-
+
def _get_nonansi_join_whereclause(self, froms):
clauses = []
-
+
def visit_join(join):
if join.isouter:
def visit_binary(binary):
binary.left = _OuterJoinColumn(binary.left)
elif binary.right.table is join.right:
binary.right = _OuterJoinColumn(binary.right)
- clauses.append(visitors.traverse(join.onclause, visit_binary=visit_binary, clone=True))
+ clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
else:
clauses.append(join.onclause)
-
+
for f in froms:
- visitors.traverse(f, visit_join=visit_join)
+ visitors.traverse(f, {}, {'join':visit_join})
return sql.and_(*clauses)
-
+
def visit_outer_join_column(self, vc):
return self.process(vc.column) + "(+)"
if whereclause:
select = select.where(whereclause)
select._oracle_visit = True
-
+
if select._limit is not None or select._offset is not None:
# to use ROW_NUMBER(), an ORDER BY is required.
orderby = self.process(select._order_by_clause)
select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
select._oracle_visit = True
-
+
limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
limitselect._oracle_visit = True
limitselect._is_wrapper = True
-
+
if select._offset is not None:
limitselect.append_whereclause("ora_rn>%d" % select._offset)
if select._limit is not None:
else:
limitselect.append_whereclause("ora_rn<=%d" % select._limit)
select = limitselect
-
+
kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
return ""
def for_update_clause(self, select):
- if select.for_update=="nowait":
+ if select.for_update == "nowait":
return " FOR UPDATE NOWAIT"
else:
return super(OracleCompiler, self).for_update_clause(select)
class OracleSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
- colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
import random, re, string
-from sqlalchemy import sql, schema, exceptions, util
+from sqlalchemy import sql, schema, exc, util
from sqlalchemy.engine import base, default
from sqlalchemy.sql import compiler, expression
from sqlalchemy.sql import operators as sql_operators
class PGString(sqltypes.String):
def get_col_spec(self):
- return "VARCHAR(%(length)s)" % {'length' : self.length}
+ if self.length:
+ return "VARCHAR(%(length)d)" % {'length' : self.length}
+ else:
+ return "VARCHAR"
class PGChar(sqltypes.CHAR):
def get_col_spec(self):
- return "CHAR(%(length)s)" % {'length' : self.length}
+ if self.length:
+ return "CHAR(%(length)d)" % {'length' : self.length}
+ else:
+ return "CHAR"
class PGBinary(sqltypes.Binary):
def get_col_spec(self):
if value is None:
return value
def convert_item(item):
- if isinstance(item, (list,tuple)):
+ if isinstance(item, (list, tuple)):
return [convert_item(child) for child in item]
else:
if item_proc:
def last_inserted_ids(self):
if self.context.last_inserted_ids is None:
- raise exceptions.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled")
+ raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled")
else:
return self.context.last_inserted_ids
v = connection.execute("select version()").scalar()
m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v)
if not m:
- raise exceptions.AssertionError("Could not determine version from string '%s'" % v)
+ raise AssertionError("Could not determine version from string '%s'" % v)
return tuple([int(x) for x in m.group(1, 2, 3)])
def reflecttable(self, connection, table, include_columns):
rows = c.fetchall()
if not rows:
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
domains = self._load_domains(connection)
default = domain['default']
coltype = ischema_names[domain['attype']]
else:
- coltype=None
+ coltype = None
if coltype:
coltype = coltype(*args, **kwargs)
(attype, name))
coltype = sqltypes.NULLTYPE
- colargs= []
+ colargs = []
if default is not None:
match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
if match is not None:
col = table.c[pk]
table.primary_key.add(col)
if col.default is None:
- col.autoincrement=False
+ col.autoincrement = False
# Foreign keys
FK_SQL = """
yield co
else:
yield c
- columns = [self.process(c) for c in flatten_columnlist(returning_cols)]
+ columns = [self.process(c, render_labels=True) for c in flatten_columnlist(returning_cols)]
text += ' RETURNING ' + string.join(columns, ', ')
return text
else:
colspec += " SERIAL"
else:
- colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
import datetime, re, time
-from sqlalchemy import schema, exceptions, pool, PassiveDefault
+from sqlalchemy import schema, exc, pool, PassiveDefault
from sqlalchemy.engine import default
import sqlalchemy.types as sqltypes
import sqlalchemy.util as util
microsecond = 0
return time.strptime(value, self.__format__)[0:6] + (microsecond,)
-class SLDateTime(DateTimeMixin,sqltypes.DateTime):
+class SLDateTime(DateTimeMixin, sqltypes.DateTime):
__format__ = "%Y-%m-%d %H:%M:%S"
__microsecond__ = True
class SLString(sqltypes.String):
def get_col_spec(self):
- return "VARCHAR(%(length)s)" % {'length' : self.length}
+ return "VARCHAR" + (self.length and "(%d)" % self.length or "")
class SLChar(sqltypes.CHAR):
def get_col_spec(self):
- return "CHAR(%(length)s)" % {'length' : self.length}
+ return "CHAR" + (self.length and "(%d)" % self.length or "")
class SLBinary(sqltypes.Binary):
def get_col_spec(self):
return tuple([int(x) for x in num.split('.')])
if self.dbapi is not None:
sqlite_ver = self.dbapi.version_info
- if sqlite_ver < (2,1,'3'):
+ if sqlite_ver < (2, 1, '3'):
util.warn(
("The installed version of pysqlite2 (%s) is out-dated "
"and will cause errors in some cases. Version 2.1.3 "
def create_connect_args(self, url):
if url.username or url.password or url.host or url.port:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Invalid SQLite URL: %s\n"
"Valid SQLite URL forms are:\n"
" sqlite:///:memory: (or, sqlite://)\n"
" SELECT * FROM sqlite_temp_master) "
"WHERE type='table' ORDER BY name")
rs = connection.execute(s)
- except exceptions.DBAPIError:
+ except exc.DBAPIError:
raise
s = ("SELECT name FROM sqlite_master "
"WHERE type='table' ORDER BY name")
args = re.findall(r'(\d+)', args)
coltype = coltype(*[int(a) for a in args])
- colargs= []
+ colargs = []
if has_default:
colargs.append(PassiveDefault('?'))
table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
if not found_table:
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))
fks = {}
try:
fk = fks[constraint_name]
except KeyError:
- fk = ([],[])
+ fk = ([], [])
fks[constraint_name] = fk
# look up the table based on the given table's engine, not 'self',
class SQLiteSchemaGenerator(compiler.SchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
import datetime, operator
-from sqlalchemy import util, sql, schema, exceptions
+from sqlalchemy import util, sql, schema, exc
from sqlalchemy.sql import compiler, expression
from sqlalchemy.engine import default, base
from sqlalchemy import types as sqltypes
def bind_processor(self, dialect):
def process(value):
- raise exceptions.NotSupportedError("Data type not supported", [value])
+ raise exc.NotSupportedError("Data type not supported", [value])
return process
def get_col_spec(self):
- raise exceptions.NotSupportedError("Data type not supported")
+ raise exc.NotSupportedError("Data type not supported")
class SybaseNumeric(sqltypes.Numeric):
def get_col_spec(self):
dialect_cls = dialect_mapping[module_name]
return dialect_cls.import_dbapi()
except KeyError:
- raise exceptions.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name)
+ raise exc.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name)
else:
for dialect_cls in dialect_mapping.values():
try:
self.context.rowcount = c.rowcount
c.DBPROP_COMMITPRESERVE = "Y"
except Exception, e:
- raise exceptions.DBAPIError.instance(statement, parameters, e)
+ raise exc.DBAPIError.instance(statement, parameters, e)
def table_names(self, connection, schema):
"""Ignore the schema and the charset for now."""
(type, name))
coltype = sqltypes.NULLTYPE
coltype = coltype(*args)
- colargs= []
+ colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
row[0], row[1], row[2], row[3],
)
if not primary_table in foreignKeys.keys():
- foreignKeys[primary_table] = [['%s'%(foreign_column)], ['%s.%s'%(primary_table,primary_column)]]
+ foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
else:
foreignKeys[primary_table][0].append('%s'%(foreign_column))
- foreignKeys[primary_table][1].append('%s.%s'%(primary_table,primary_column))
+ foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
for primary_table in foreignKeys.keys():
#table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1]))
if not found_table:
- raise exceptions.NoSuchTableError(table.name)
+ raise exc.NoSuchTableError(table.name)
class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
def bindparam_string(self, name):
res = super(SybaseSQLCompiler, self).bindparam_string(name)
if name.lower().startswith('literal'):
- res = 'STRING(%s)'%res
+ res = 'STRING(%s)' % res
return res
def get_select_precolumns(self, select):
#colspec += " numeric(30,0) IDENTITY"
colspec += " Integer IDENTITY"
else:
- colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
if not column.nullable:
colspec += " NOT NULL"
"""
import inspect, StringIO, sys
-from sqlalchemy import exceptions, schema, util, types, logging
+from sqlalchemy import exc, schema, util, types, log
from sqlalchemy.sql import expression
self.statement = statement
self.column_keys = column_keys
self.bind = bind
- self.can_execute = statement.supports_execution()
+ self.can_execute = statement.supports_execution
def compile(self):
"""Produce the internal string representation of this element."""
e = self.bind
if e is None:
- raise exceptions.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.")
+ raise exc.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.")
return e._execute_compiled(self, multiparams, params)
def scalar(self, *multiparams, **params):
self.__savepoint_seq = 0
self.__branch = _branch
self.__invalid = False
-
+
def _branch(self):
"""Return a new Connection which references this Connection's
engine and connection; but does not have close_with_result enabled,
This is used to execute "sub" statements within a single execution,
usually an INSERT statement.
"""
- return Connection(self.engine, self.__connection, _branch=True)
+ return self.engine.Connection(self.engine, self.__connection, _branch=True)
def dialect(self):
"Dialect used by this Connection."
except AttributeError:
if self.__invalid:
if self.__transaction is not None:
- raise exceptions.InvalidRequestError("Can't reconnect until invalid transaction is rolled back")
+ raise exc.InvalidRequestError("Can't reconnect until invalid transaction is rolled back")
self.__connection = self.engine.raw_connection()
self.__invalid = False
return self.__connection
- raise exceptions.InvalidRequestError("This Connection is closed")
+ raise exc.InvalidRequestError("This Connection is closed")
connection = property(connection)
def should_close_with_result(self):
"""
if self.__transaction is not None:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"Cannot start a two phase transaction when a transaction "
"is already in progress.")
if xid is None:
if c in Connection.executors:
return Connection.executors[c](self, object, multiparams, params)
else:
- raise exceptions.InvalidRequestError("Unexecutable object type: " + str(type(object)))
+ raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object)))
def _execute_default(self, default, multiparams=None, params=None):
return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
in the case of 'raw' execution which accepts positional parameters,
it may be a list of tuples or lists."""
- if multiparams is None or len(multiparams) == 0:
+ if not multiparams:
if params:
return [params]
else:
def _execute_compiled(self, compiled, multiparams=None, params=None, distilled_params=None):
"""Execute a sql.Compiled object."""
if not compiled.can_execute:
- raise exceptions.ArgumentError("Not an executable clause: %s" % (str(compiled)))
+ raise exc.ArgumentError("Not an executable clause: %s" % (str(compiled)))
if distilled_params is None:
distilled_params = self.__distill_params(multiparams, params)
def _handle_dbapi_exception(self, e, statement, parameters, cursor):
if getattr(self, '_reentrant_error', False):
- raise exceptions.DBAPIError.instance(None, None, e)
+ raise exc.DBAPIError.instance(None, None, e)
self._reentrant_error = True
try:
if not isinstance(e, self.dialect.dbapi.Error):
self._autorollback()
if self.__close_with_result:
self.close()
- raise exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
+ raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
finally:
del self._reentrant_error
def commit(self):
if not self._parent._is_active:
- raise exceptions.InvalidRequestError("This transaction is inactive")
+ raise exc.InvalidRequestError("This transaction is inactive")
self._do_commit()
self._is_active = False
def prepare(self):
if not self._parent._is_active:
- raise exceptions.InvalidRequestError("This transaction is inactive")
+ raise exc.InvalidRequestError("This transaction is inactive")
self._connection._prepare_twophase_impl(self.xid)
self._is_prepared = True
provide a default implementation of SchemaEngine.
"""
- def __init__(self, pool, dialect, url, echo=None):
+ def __init__(self, pool, dialect, url, echo=None, proxy=None):
self.pool = pool
self.url = url
- self.dialect=dialect
+ self.dialect = dialect
self.echo = echo
self.engine = self
- self.logger = logging.instance_logger(self, echoflag=echo)
+ self.logger = log.instance_logger(self, echoflag=echo)
+ if proxy:
+ self.Connection = _proxy_connection_cls(Connection, proxy)
+ else:
+ self.Connection = Connection
def name(self):
"String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``."
return sys.modules[self.dialect.__module__].descriptor()['name']
name = property(name)
- echo = logging.echo_property()
+ echo = log.echo_property()
def __repr__(self):
return 'Engine(%s)' % str(self.url)
def connect(self, **kwargs):
"""Return a newly allocated Connection object."""
- return Connection(self, **kwargs)
+ return self.Connection(self, **kwargs)
def contextual_connect(self, close_with_result=False, **kwargs):
"""Return a Connection object which may be newly allocated, or may be part of some ongoing context.
This Connection is meant to be used by the various "auto-connecting" operations.
"""
- return Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs)
+ return self.Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs)
def table_names(self, schema=None, connection=None):
"""Return a list of all table names available in the database.
return self.pool.unique_connection()
+def _proxy_connection_cls(cls, proxy):
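+    # subclass the given Connection class, handing statement- and
+    # cursor-level execution over to the proxy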
+ class ProxyConnection(cls):
+ def execute(self, object, *multiparams, **params):
+ return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params)
+
+ def execute_clauseelement(self, elem, multiparams=None, params=None):
+ return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {}))
+
+ def _cursor_execute(self, cursor, statement, parameters, context=None):
+ return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False)
+
+ def _cursor_executemany(self, cursor, statement, parameters, context=None):
+ return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True)
+
+ return ProxyConnection
+
class RowProxy(object):
"""Proxy a single cursor row for a parent ResultProxy.
results that correspond to constructed SQL expressions).
"""
+ __slots__ = ['__parent', '__row']
+
def __init__(self, parent, row):
"""RowProxy objects are constructed by ResultProxy objects."""
return props[key._label.lower()]
elif hasattr(key, 'name') and key.name.lower() in props:
return props[key.name.lower()]
- raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
+ raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
return rec
return util.PopulateDict(lookup_key)
def __ambiguous_processor(self, colname):
def process(value):
- raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname)
+ raise exc.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname)
return process
def close(self):
"""
-
import re, random
from sqlalchemy.engine import base
from sqlalchemy.sql import compiler, expression
This id will be passed to do_begin_twophase(), do_rollback_twophase(),
do_commit_twophase(). Its format is unspecified."""
- return "_sa_%032x" % random.randint(0,2**128)
+ return "_sa_%032x" % random.randint(0, 2 ** 128)
def do_savepoint(self, connection, name):
connection.execute(expression.SavepointClause(name))
if self.dialect.positional:
inputsizes = []
for key in self.compiled.positiontup:
- typeengine = types[key]
- dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None:
+ typeengine = types[key]
+ dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+ if dbtype is not None:
inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*inputsizes)
self._last_updated_params = compiled_parameters
self.postfetch_cols = self.compiled.postfetch
- self.prefetch_cols = self.compiled.prefetch
\ No newline at end of file
+ self.prefetch_cols = self.compiled.prefetch
from sqlalchemy.engine import base, threadlocal, url
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
from sqlalchemy import pool as poollib
strategies = {}
try:
return dbapi.connect(*cargs, **cparams)
except Exception, e:
- raise exceptions.DBAPIError.instance(None, None, e)
+ raise exc.DBAPIError.instance(None, None, e)
creator = kwargs.pop('creator', connect)
poolclass = (kwargs.pop('poolclass', None) or
def create(self, entity, **kwargs):
kwargs['checkfirst'] = False
- self.dialect.schemagenerator(self.dialect ,self, **kwargs).traverse(entity)
+ self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity)
def drop(self, entity, **kwargs):
kwargs['checkfirst'] = False
try:
return self.__transaction._increment_connect()
except AttributeError:
- return TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result)
+ return self.engine.TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result)
def reset(self):
try:
class TLConnection(base.Connection):
- def __init__(self, session, connection, close_with_result):
- base.Connection.__init__(self, session.engine, connection, close_with_result=close_with_result)
+ def __init__(self, session, connection, **kwargs):
+ base.Connection.__init__(self, session.engine, connection, **kwargs)
self.__session = session
self.__opencount = 1
+ def _branch(self):
+ return self.engine.Connection(self.engine, self.connection, _branch=True)
+
def session(self):
return self.__session
session = property(session)
super(TLEngine, self).__init__(*args, **kwargs)
self.context = util.ThreadLocal()
+ proxy = kwargs.get('proxy')
+ if proxy:
+ self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
+ else:
+ self.TLConnection = TLConnection
+
def session(self):
"Returns the current thread's TLSession"
if not hasattr(self.context, 'session'):
"""
import re, cgi, sys, urllib
-from sqlalchemy import exceptions
+from sqlalchemy import exc
class URL(object):
self.port = int(port)
else:
self.port = None
- self.database= database
+ self.database = database
self.query = query or {}
def __str__(self):
name = components.pop('name')
return URL(name, **components)
else:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Could not parse rfc1738 URL from string '%s'" % name)
def _parse_keyvalue_args(name):
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
"""Exceptions used with SQLAlchemy.
-The base exception class is SQLAlchemyError. Exceptions which are raised as a result
-of DBAPI exceptions are all subclasses of [sqlalchemy.exceptions#DBAPIError]."""
+The base exception class is SQLAlchemyError. Exceptions which are raised as a
+result of DBAPI exceptions are all subclasses of
+[sqlalchemy.exceptions#DBAPIError].
+
+"""
+
class SQLAlchemyError(Exception):
"""Generic error class."""
class ArgumentError(SQLAlchemyError):
- """Raised for all those conditions where invalid arguments are
- sent to constructed objects. This error generally corresponds to
- construction time state errors.
+ """Raised when an invalid or conflicting function argument is supplied.
+
+ This error generally corresponds to construction time state errors.
+
"""
+class CircularDependencyError(SQLAlchemyError):
+ """Raised by topological sorts when a circular dependency is detected"""
+
+
class CompileError(SQLAlchemyError):
"""Raised when an error occurs during SQL compilation"""
-class TimeoutError(SQLAlchemyError):
- """Raised when a connection pool times out on getting a connection."""
+# Moved to orm.exc; compatibility definition installed by orm import until 0.6
+ConcurrentModificationError = None
+class DisconnectionError(SQLAlchemyError):
+ """A disconnect is detected on a raw DB-API connection.
-class ConcurrentModificationError(SQLAlchemyError):
- """Raised when a concurrent modification condition is detected."""
+ This error is raised and consumed internally by a connection pool. It can
+ be raised by a ``PoolListener`` so that the host pool forces a disconnect.
+ """
-class CircularDependencyError(SQLAlchemyError):
- """Raised by topological sorts when a circular dependency is detected"""
+# Moved to orm.exc; compatibility definition installed by orm import until 0.6
+FlushError = None
-class FlushError(SQLAlchemyError):
- """Raised when an invalid condition is detected upon a ``flush()``."""
+class TimeoutError(SQLAlchemyError):
+ """Raised when a connection pool times out on getting a connection."""
class InvalidRequestError(SQLAlchemyError):
- """SQLAlchemy was asked to do something it can't do, return
- nonexistent data, etc.
+ """SQLAlchemy was asked to do something it can't do.
This error generally corresponds to runtime state errors.
- """
-
-class UnmappedColumnError(InvalidRequestError):
- """A mapper was asked to return mapped information about a column
- which it does not map"""
-class NoSuchTableError(InvalidRequestError):
- """SQLAlchemy was asked to load a table's definition from the
- database, but the table doesn't exist.
"""
-class UnboundExecutionError(InvalidRequestError):
- """SQL was attempted without a database connection to execute it on."""
+class NoSuchColumnError(KeyError, InvalidRequestError):
+ """A nonexistent column is requested from a ``RowProxy``."""
-class AssertionError(SQLAlchemyError):
- """Corresponds to internal state being detected in an invalid state."""
-
-
-class NoSuchColumnError(KeyError, SQLAlchemyError):
- """Raised by ``RowProxy`` when a nonexistent column is requested from a row."""
-
class NoReferencedTableError(InvalidRequestError):
"""Raised by ``ForeignKey`` when the referred ``Table`` cannot be located."""
-class DisconnectionError(SQLAlchemyError):
- """Raised within ``Pool`` when a disconnect is detected on a raw DB-API connection.
+class NoSuchTableError(InvalidRequestError):
+ """Table does not exist or is not visible to a connection."""
- This error is consumed internally by a connection pool. It can be raised by
- a ``PoolListener`` so that the host pool forces a disconnect.
- """
+class UnboundExecutionError(InvalidRequestError):
+ """SQL was attempted without a database connection to execute it on."""
+
+
+# Moved to orm.exc; compatibility definition installed by orm import until 0.6
+UnmappedColumnError = None
class DBAPIError(SQLAlchemyError):
"""Raised when the execution of a database operation fails.
The wrapped exception object is available in the ``orig`` attribute.
Its type and properties are DB-API implementation specific.
+
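+    For example, a sketch of inspecting a wrapped failure (assumes a
+    ``connection`` and a deliberately bad statement)::
+
+        try:
+            connection.execute("SELECT * FROM no_such_table")
+        except DBAPIError, e:
+            print e.orig        # the underlying DB-API exception
+            print e.statement   # the statement that failed
+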
"""
def instance(cls, statement, params, orig, connection_invalidated=False):
except Exception, e:
text = 'Error in str() of DB-API-generated exception: ' + str(e)
SQLAlchemyError.__init__(
- self, "(%s) %s" % (orig.__class__.__name__, text))
+ self, '(%s) %s' % (orig.__class__.__name__, text))
self.statement = statement
self.params = params
self.orig = orig
repr(self.statement), repr(self.params)])
-# As of 0.4, SQLError is now DBAPIError
+# As of 0.4, SQLError is now DBAPIError.
+# SQLError alias will be removed in 0.6.
SQLError = DBAPIError
class InterfaceError(DBAPIError):
"""Wraps a DB-API InterfaceError."""
+
class DatabaseError(DBAPIError):
"""Wraps a DB-API DatabaseError."""
+
class DataError(DatabaseError):
"""Wraps a DB-API DataError."""
+
class OperationalError(DatabaseError):
"""Wraps a DB-API OperationalError."""
+
class IntegrityError(DatabaseError):
"""Wraps a DB-API IntegrityError."""
+
class InternalError(DatabaseError):
"""Wraps a DB-API InternalError."""
+
class ProgrammingError(DatabaseError):
"""Wraps a DB-API ProgrammingError."""
+
class NotSupportedError(DatabaseError):
"""Wraps a DB-API NotSupportedError."""
+
# Warnings
+
class SADeprecationWarning(DeprecationWarning):
"""Issued once per usage of a deprecated API."""
+
class SAPendingDeprecationWarning(PendingDeprecationWarning):
"""Issued once per usage of a deprecated API."""
+
class SAWarning(RuntimeWarning):
"""Issued at runtime."""
+++ /dev/null
-from sqlalchemy import ThreadLocalMetaData, util, Integer
-from sqlalchemy import Table, Column, ForeignKey
-from sqlalchemy.orm import class_mapper, relation, scoped_session
-from sqlalchemy.orm import sessionmaker
-
-from sqlalchemy.orm import backref as create_backref
-
-import inspect
-import sys
-
-#
-# the "proxy" to the database engine... this can be swapped out at runtime
-#
-metadata = ThreadLocalMetaData()
-Objectstore = scoped_session
-objectstore = scoped_session(sessionmaker(autoflush=True, transactional=False))
-
-#
-# declarative column declaration - this is so that we can infer the colname
-#
-class column(object):
- def __init__(self, coltype, colname=None, foreign_key=None,
- primary_key=False, *args, **kwargs):
- if isinstance(foreign_key, basestring):
- foreign_key = ForeignKey(foreign_key)
-
- self.coltype = coltype
- self.colname = colname
- self.foreign_key = foreign_key
- self.primary_key = primary_key
- self.kwargs = kwargs
- self.args = args
-
-#
-# declarative relationship declaration
-#
-class relationship(object):
- def __init__(self, classname, colname=None, backref=None, private=False,
- lazy=True, uselist=True, secondary=None, order_by=False, viewonly=False):
- self.classname = classname
- self.colname = colname
- self.backref = backref
- self.private = private
- self.lazy = lazy
- self.uselist = uselist
- self.secondary = secondary
- self.order_by = order_by
- self.viewonly = viewonly
-
- def process(self, klass, propname, relations):
- relclass = ActiveMapperMeta.classes[self.classname]
-
- if isinstance(self.order_by, str):
- self.order_by = [ self.order_by ]
-
- if isinstance(self.order_by, list):
- for itemno in range(len(self.order_by)):
- if isinstance(self.order_by[itemno], str):
- self.order_by[itemno] = \
- getattr(relclass.c, self.order_by[itemno])
-
- backref = self.create_backref(klass)
- relations[propname] = relation(relclass.mapper,
- secondary=self.secondary,
- backref=backref,
- private=self.private,
- lazy=self.lazy,
- uselist=self.uselist,
- order_by=self.order_by,
- viewonly=self.viewonly)
-
- def create_backref(self, klass):
- if self.backref is None:
- return None
-
- relclass = ActiveMapperMeta.classes[self.classname]
-
- if klass.__name__ == self.classname:
- class_mapper(relclass).compile()
- br_fkey = relclass.c[self.colname]
- else:
- br_fkey = None
-
- return create_backref(self.backref, remote_side=br_fkey)
-
-
-class one_to_many(relationship):
- def __init__(self, *args, **kwargs):
- kwargs['uselist'] = True
- relationship.__init__(self, *args, **kwargs)
-
-class one_to_one(relationship):
- def __init__(self, *args, **kwargs):
- kwargs['uselist'] = False
- relationship.__init__(self, *args, **kwargs)
-
- def create_backref(self, klass):
- if self.backref is None:
- return None
-
- relclass = ActiveMapperMeta.classes[self.classname]
-
- if klass.__name__ == self.classname:
- br_fkey = getattr(relclass.c, self.colname)
- else:
- br_fkey = None
-
- return create_backref(self.backref, foreignkey=br_fkey, uselist=False)
-
-
-class many_to_many(relationship):
- def __init__(self, classname, secondary, backref=None, lazy=True,
- order_by=False):
- relationship.__init__(self, classname, None, backref, False, lazy,
- uselist=True, secondary=secondary,
- order_by=order_by)
-
-
-#
-# SQLAlchemy metaclass and superclass that can be used to do SQLAlchemy
-# mapping in a declarative way, along with a function to process the
-# relationships between dependent objects as they come in, without blowing
-# up if the classes aren't specified in a proper order
-#
-
-__deferred_classes__ = {}
-__processed_classes__ = {}
-def process_relationships(klass, was_deferred=False):
- # first, we loop through all of the relationships defined on the
- # class, and make sure that the related class already has been
- # completely processed and defer processing if it has not
- defer = False
- for propname, reldesc in klass.relations.items():
- found = (reldesc.classname == klass.__name__ or reldesc.classname in __processed_classes__)
- if not found:
- defer = True
- break
-
- # next, we loop through all the columns looking for foreign keys
- # and make sure that we can find the related tables (they do not
- # have to be processed yet, just defined), and we defer if we are
- # not able to find any of the related tables
- if not defer:
- for col in klass.columns:
- if col.foreign_keys:
- found = False
- cn = col.foreign_keys[0]._colspec
- table_name = cn[:cn.rindex('.')]
- for other_klass in ActiveMapperMeta.classes.values():
- if other_klass.table.fullname.lower() == table_name.lower():
- found = True
-
- if not found:
- defer = True
- break
-
- if defer and not was_deferred:
- __deferred_classes__[klass.__name__] = klass
-
- # if we are able to find all related and referred to tables, then
- # we can go ahead and assign the relationships to the class
- if not defer:
- relations = {}
- for propname, reldesc in klass.relations.items():
- reldesc.process(klass, propname, relations)
-
- class_mapper(klass).add_properties(relations)
- if klass.__name__ in __deferred_classes__:
- del __deferred_classes__[klass.__name__]
- __processed_classes__[klass.__name__] = klass
-
- # finally, loop through the deferred classes and attempt to process
- # relationships for them
- if not was_deferred:
- # loop through the list of deferred classes, processing the
- # relationships, until we can make no more progress
- last_count = len(__deferred_classes__) + 1
- while last_count > len(__deferred_classes__):
- last_count = len(__deferred_classes__)
- deferred = __deferred_classes__.copy()
- for deferred_class in deferred.values():
- process_relationships(deferred_class, was_deferred=True)
-
-
-class ActiveMapperMeta(type):
- classes = {}
- metadatas = util.Set()
- def __init__(cls, clsname, bases, dict):
- table_name = clsname.lower()
- columns = []
- relations = {}
- autoload = False
- _metadata = getattr(sys.modules[cls.__module__],
- "__metadata__", metadata)
- version_id_col = None
- version_id_col_object = None
- table_opts = {}
-
- if 'mapping' in dict:
- found_pk = False
-
- members = inspect.getmembers(dict.get('mapping'))
- for name, value in members:
- if name == '__table__':
- table_name = value
- continue
-
- if '__metadata__' == name:
- _metadata= value
- continue
-
- if '__autoload__' == name:
- autoload = True
- continue
-
- if '__version_id_col__' == name:
- version_id_col = value
-
- if '__table_opts__' == name:
- table_opts = value
-
- if name.startswith('__'): continue
-
- if isinstance(value, column):
- if value.primary_key == True: found_pk = True
-
- if value.foreign_key:
- col = Column(value.colname or name,
- value.coltype,
- value.foreign_key,
- primary_key=value.primary_key,
- *value.args, **value.kwargs)
- else:
- col = Column(value.colname or name,
- value.coltype,
- primary_key=value.primary_key,
- *value.args, **value.kwargs)
- columns.append(col)
- continue
-
- if isinstance(value, relationship):
- relations[name] = value
-
- if not found_pk and not autoload:
- col = Column('id', Integer, primary_key=True)
- cls.mapping.id = col
- columns.append(col)
-
- assert _metadata is not None, "No MetaData specified"
-
- ActiveMapperMeta.metadatas.add(_metadata)
-
- if not autoload:
- cls.table = Table(table_name, _metadata, *columns, **table_opts)
- cls.columns = columns
- else:
- cls.table = Table(table_name, _metadata, autoload=True, **table_opts)
- cls.columns = cls.table._columns
-
- if version_id_col is not None:
- version_id_col_object = getattr(cls.table.c, version_id_col, None)
- assert(version_id_col_object is not None, "version_id_col (%s) does not exist." % version_id_col)
-
- # check for inheritence
- if hasattr(bases[0], "mapping"):
- cls._base_mapper= bases[0].mapper
- cls.mapper = objectstore.mapper(cls, cls.table,
- inherits=cls._base_mapper, version_id_col=version_id_col_object)
- else:
- cls.mapper = objectstore.mapper(cls, cls.table, version_id_col=version_id_col_object)
- cls.relations = relations
- ActiveMapperMeta.classes[clsname] = cls
-
- process_relationships(cls)
-
- super(ActiveMapperMeta, cls).__init__(clsname, bases, dict)
-
-
-
-class ActiveMapper(object):
- __metaclass__ = ActiveMapperMeta
-
- def set(self, **kwargs):
- for key, value in kwargs.items():
- setattr(self, key, value)
-
-
-#
-# a utility function to create all tables for all ActiveMapper classes
-#
-
-def create_tables():
- for metadata in ActiveMapperMeta.metadatas:
- metadata.create_all()
-
-def drop_tables():
- for metadata in ActiveMapperMeta.metadatas:
- metadata.drop_all()
+++ /dev/null
-from sqlalchemy import util, exceptions
-import types
-from sqlalchemy.orm import mapper, Query
-
-def _monkeypatch_query_method(name, ctx, class_):
- def do(self, *args, **kwargs):
- query = Query(class_, session=ctx.current)
- util.warn_deprecated('Query methods on the class are deprecated; use %s.query.%s instead' % (class_.__name__, name))
- return getattr(query, name)(*args, **kwargs)
- try:
- do.__name__ = name
- except:
- pass
- if not hasattr(class_, name):
- setattr(class_, name, classmethod(do))
-
-def _monkeypatch_session_method(name, ctx, class_):
- def do(self, *args, **kwargs):
- session = ctx.current
- return getattr(session, name)(self, *args, **kwargs)
- try:
- do.__name__ = name
- except:
- pass
- if not hasattr(class_, name):
- setattr(class_, name, do)
-
-def assign_mapper(ctx, class_, *args, **kwargs):
- extension = kwargs.pop('extension', None)
- if extension is not None:
- extension = util.to_list(extension)
- extension.append(ctx.mapper_extension)
- else:
- extension = ctx.mapper_extension
-
- validate = kwargs.pop('validate', False)
-
- if not isinstance(getattr(class_, '__init__'), types.MethodType):
- def __init__(self, **kwargs):
- for key, value in kwargs.items():
- if validate:
- if not self.mapper.get_property(key,
- resolve_synonyms=False,
- raiseerr=False):
- raise exceptions.ArgumentError(
- "Invalid __init__ argument: '%s'" % key)
- setattr(self, key, value)
- class_.__init__ = __init__
-
- class query(object):
- def __getattr__(self, key):
- return getattr(ctx.current.query(class_), key)
- def __call__(self):
- return ctx.current.query(class_)
-
- if not hasattr(class_, 'query'):
- class_.query = query()
-
- for name in ('get', 'filter', 'filter_by', 'select', 'select_by',
- 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by',
- 'get_by', 'join_to', 'join_via', 'count', 'count_by',
- 'options', 'instances'):
- _monkeypatch_query_method(name, ctx, class_)
- for name in ('refresh', 'expire', 'delete', 'expunge', 'update'):
- _monkeypatch_session_method(name, ctx, class_)
-
- m = mapper(class_, extension=extension, *args, **kwargs)
- class_.mapper = m
- return m
-
-assign_mapper = util.deprecated(
- "assign_mapper is deprecated. Use scoped_session() instead.")(assign_mapper)
def clear(self):
del self.col[0:len(self.col)]
- def __eq__(self, other): return list(self) == other
- def __ne__(self, other): return list(self) != other
- def __lt__(self, other): return list(self) < other
- def __le__(self, other): return list(self) <= other
- def __gt__(self, other): return list(self) > other
- def __ge__(self, other): return list(self) >= other
- def __cmp__(self, other): return cmp(list(self), other)
+ def __eq__(self, other):
+ return list(self) == other
+
+ def __ne__(self, other):
+ return list(self) != other
+
+ def __lt__(self, other):
+ return list(self) < other
+
+ def __le__(self, other):
+ return list(self) <= other
+
+ def __gt__(self, other):
+ return list(self) > other
+
+ def __ge__(self, other):
+ return list(self) >= other
+
+ def __cmp__(self, other):
+ return cmp(list(self), other)
def __add__(self, iterable):
try:
def clear(self):
self.col.clear()
- def __eq__(self, other): return dict(self) == other
- def __ne__(self, other): return dict(self) != other
- def __lt__(self, other): return dict(self) < other
- def __le__(self, other): return dict(self) <= other
- def __gt__(self, other): return dict(self) > other
- def __ge__(self, other): return dict(self) >= other
- def __cmp__(self, other): return cmp(dict(self), other)
+ def __eq__(self, other):
+ return dict(self) == other
+
+ def __ne__(self, other):
+ return dict(self) != other
+
+ def __lt__(self, other):
+ return dict(self) < other
+
+ def __le__(self, other):
+ return dict(self) <= other
+
+ def __gt__(self, other):
+ return dict(self) > other
+
+ def __ge__(self, other):
+ return dict(self) >= other
+
+ def __cmp__(self, other):
+ return cmp(dict(self), other)
def __repr__(self):
return repr(dict(self.items()))
def copy(self):
return util.Set(self)
- def __eq__(self, other): return util.Set(self) == other
- def __ne__(self, other): return util.Set(self) != other
- def __lt__(self, other): return util.Set(self) < other
- def __le__(self, other): return util.Set(self) <= other
- def __gt__(self, other): return util.Set(self) > other
- def __ge__(self, other): return util.Set(self) >= other
+ def __eq__(self, other):
+ return util.Set(self) == other
+
+ def __ne__(self, other):
+ return util.Set(self) != other
+
+ def __lt__(self, other):
+ return util.Set(self) < other
+
+ def __le__(self, other):
+ return util.Set(self) <= other
+
+ def __gt__(self, other):
+ return util.Set(self) > other
+
+ def __ge__(self, other):
+ return util.Set(self) >= other
def __repr__(self):
return repr(util.Set(self))
continue
prop = _deferred_relation(cls, value)
our_stuff[k] = prop
+
+ # set up attributes in the order they were created
+ our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order, our_stuff[y]._creation_order))
table = None
if '__table__' not in cls.__dict__:
mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
else:
mapper_cls = mapper
+
cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)
return type.__init__(cls, classname, bases, dict_)
u = User()
u.topten.append(Blurb('Number one!'))
u.topten.append(Blurb('Number two!'))
-
+
# Like magic.
assert [blurb.position for blurb in u.topten] == [0, 1]
def ordering_list(attr, count_from=None, **kw):
"""Prepares an OrderingList factory for use in mapper definitions.
-
+
Returns an object suitable for use as an argument to a Mapper relation's
``collection_class`` option. Arguments are:
example, ``ordering_list('pos', count_from=1)`` would create a 1-based
list in SQL, storing the value in the 'pos' column. Ignored if
``ordering_func`` is supplied.
-
+
Passes along any keyword arguments to ``OrderingList`` constructor.
"""
Keyword argument filter, prepares a simple ``ordering_func`` from a
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
"""
-
+
count_from = kw.pop('count_from', None)
if kw.get('ordering_func', None) is None and count_from is not None:
if count_from == 0:
``ordering_list`` function is used to configure ``OrderingList``
collections in ``mapper`` relation definitions.
"""
-
+
def __init__(self, ordering_attr=None, ordering_func=None,
reorder_on_append=False):
"""A custom list that manages position information for its children.
-
+
``OrderingList`` is a ``collection_class`` list implementation that
syncs position in a Python list with a position attribute on the
mapped objects.
An ``ordering_func`` is called with two positional parameters: the
index of the element in the list, and the list itself.
-
+
If omitted, Python list indexes are used for the attribute values.
Two basic pre-built numbering functions are provided in this module:
``count_from_0`` and ``count_from_1``. For more exotic examples
def _reorder(self):
"""Sweep through the list and ensure that each object has accurate
ordering information set."""
-
+
for index, entity in enumerate(self):
self._order_entity(index, entity, True)
return
should_be = self.ordering_func(index, self)
- if have <> should_be:
+ if have != should_be:
self._set_order_value(entity, should_be)
def append(self, entity):
entity = super(OrderingList, self).pop(index)
self._reorder()
return entity
-
+
def __setitem__(self, index, entity):
if isinstance(index, slice):
for i in range(index.start or 0, index.stop or 0, index.step or 1):
else:
self._order_entity(index, entity, True)
super(OrderingList, self).__setitem__(index, entity)
-
+
def __delitem__(self, index):
super(OrderingList, self).__delitem__(index)
self._reorder()
+++ /dev/null
-"""SelectResults has been rolled into Query. This class is now just a placeholder."""
-
-import sqlalchemy.sql as sql
-import sqlalchemy.orm as orm
-
-class SelectResultsExt(orm.MapperExtension):
- """a MapperExtension that provides SelectResults functionality for the
- results of query.select_by() and query.select()"""
-
- def select_by(self, query, *args, **params):
- q = query
- for a in args:
- q = q.filter(a)
- return q.filter_by(**params)
-
- def select(self, query, arg=None, **kwargs):
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- return orm.EXT_CONTINUE
- else:
- if arg is not None:
- query = query.filter(arg)
- return query._legacy_select_kwargs(**kwargs)
-
-def SelectResults(query, clause=None, ops={}):
- if clause is not None:
- query = query.filter(clause)
- query = query.options(orm.extension(SelectResultsExt()))
- return query._legacy_select_kwargs(**ops)
+++ /dev/null
-from sqlalchemy.orm.scoping import ScopedSession, _ScopedExt
-from sqlalchemy.util import warn_deprecated
-from sqlalchemy.orm import create_session
-
-__all__ = ['SessionContext', 'SessionContextExt']
-
-
-class SessionContext(ScopedSession):
- """Provides thread-local management of Sessions.
-
- Usage::
-
- context = SessionContext(sessionmaker(autoflush=True))
-
- """
-
- def __init__(self, session_factory=None, scopefunc=None):
- warn_deprecated("SessionContext is deprecated. Use scoped_session().")
- if session_factory is None:
- session_factory=create_session
- super(SessionContext, self).__init__(session_factory, scopefunc=scopefunc)
-
- def get_current(self):
- return self.registry()
-
- def set_current(self, session):
- self.registry.set(session)
-
- def del_current(self):
- self.registry.clear()
-
- current = property(get_current, set_current, del_current,
- """Property used to get/set/del the session in the current scope.""")
-
- def _get_mapper_extension(self):
- try:
- return self._extension
- except AttributeError:
- self._extension = ext = SessionContextExt(self)
- return ext
-
- mapper_extension = property(_get_mapper_extension,
- doc="""Get a mapper extension that implements `get_session` using this context. Deprecated.""")
-
-
-class SessionContextExt(_ScopedExt):
- def __init__(self, *args, **kwargs):
- warn_deprecated("SessionContextExt is deprecated. Use ScopedSession(enhance_classes=True)")
- super(SessionContextExt, self).__init__(*args, **kwargs)
-
Accessing the Session
---------------------
-SqlSoup uses a SessionContext to provide thread-local sessions. You
+SqlSoup uses a ScopedSession to provide thread-local sessions. You
can get a reference to the current one like this::
>>> from sqlalchemy.ext.sqlsoup import objectstore
from sqlalchemy import *
from sqlalchemy import schema, sql
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.scoping import ScopedSession
from sqlalchemy.exceptions import *
from sqlalchemy.sql import expression
#
-# thread local SessionContext
+# thread local ScopedSession
#
-class Objectstore(SessionContext):
+class Objectstore(ScopedSession):
def __getattr__(self, key):
- return getattr(self.current, key)
+        if key.startswith('__'):
+            # don't trip the registry for module-level sweeps of things
+            # like '__bases__'.  the session gets bound to the module,
+            # which is interfered with by other unit tests.
+            # (removal of mapper.get_session() revealed the issue)
+            raise AttributeError()
+ return getattr(self.registry(), key)
+ def current(self):
+ return self.registry()
+ current = property(current)
def get_session(self):
- return self.current
+ return self.registry()
objectstore = Objectstore(create_session)
-class PKNotFoundError(SQLAlchemyError): pass
+class PKNotFoundError(SQLAlchemyError):
+ pass
def _ddl_error(cls):
msg = 'SQLSoup can only modify mapped Tables (found: %s)' \
def _selectable_name(selectable):
if isinstance(selectable, sql.Alias):
- return _selectable_name(selectable.selectable)
+ return _selectable_name(selectable.element)
elif isinstance(selectable, sql.Select):
return ''.join([_selectable_name(s) for s in selectable.froms])
elif isinstance(selectable, schema.Table):
klass = TableClassType(mapname, (object,), {})
else:
klass = SelectableClassType(mapname, (object,), {})
-
+
def __cmp__(self, o):
L = self.__class__.c.keys()
L.sort()
for m in ['__cmp__', '__repr__']:
setattr(klass, m, eval(m))
klass._table = selectable
+ klass.c = expression.ColumnCollection()
mappr = mapper(klass,
selectable,
- extension=objectstore.mapper_extension,
+ extension=objectstore.extension,
allow_null_pks=_is_outer_join(selectable),
**mapper_kwargs)
- klass._query = Query(mappr)
+
+ for k in mappr.iterate_properties:
+ klass.c[k.key] = k.columns[0]
+
+ klass._query = objectstore.query_property()
return klass
class SqlSoup:
The ``_ConnectionFairy`` which manages the connection for the span of
the current checkout.
- If you raise an ``exceptions.DisconnectionError``, the current
+ If you raise an ``exc.DisconnectionError``, the current
connection will be disposed and a fresh connection retrieved.
Processing of all checkout listeners will abort and restart
using the new connection.
The ``_ConnectionRecord`` that persistently manages the connection
"""
+
+class ConnectionProxy(object):
+ """Allows interception of statement execution by Connections.
+
+    Subclass ``ConnectionProxy``, overriding either or both of
+    ``execute()`` and ``cursor_execute()``. The default behavior is
+    provided, which is to call the given executor function with the
+    remaining arguments. The proxy is then connected to an engine via
+ ``create_engine(url, proxy=MyProxy())`` where ``MyProxy`` is
+ the user-defined ``ConnectionProxy`` class.
+
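+    A minimal sketch (``LogProxy`` is an example name, not part of the
+    library)::
+
+        class LogProxy(ConnectionProxy):
+            def cursor_execute(self, execute, cursor, statement,
+                               parameters, context, executemany):
+                print statement   # log each statement before running it
+                return execute(cursor, statement, parameters, context)
+
+        e = create_engine('sqlite://', proxy=LogProxy())
+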
+ """
+ def execute(self, conn, execute, clauseelement, *multiparams, **params):
+ """"""
+ return execute(clauseelement, *multiparams, **params)
+
+ def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+ """"""
+ return execute(cursor, statement, parameters, context)
+
+
-# logging.py - adapt python logging module to SQLAlchemy
+# log.py - adapt python logging module to SQLAlchemy
# Copyright (C) 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
logger.setLevel(logging.DEBUG)
"""
-import sys, warnings
-import sqlalchemy.exceptions as sa_exc
+import logging
+import sys
-# py2.5 absolute imports will fix....
-logging = __import__('logging')
-
-# moved to sqlalchemy.exceptions. this alias will be removed in 0.5.
-SADeprecationWarning = sa_exc.SADeprecationWarning
rootlogger = logging.getLogger('sqlalchemy')
if rootlogger.level == logging.NOTSET:
rootlogger.setLevel(logging.WARN)
-warnings.filterwarnings("once", category=sa_exc.SADeprecationWarning)
default_enabled = False
def default_logging(name):
global default_enabled
if logging.getLogger(name).getEffectiveLevel() < logging.WARN:
- default_enabled=True
+ default_enabled = True
if not default_enabled:
default_enabled = True
handler = logging.StreamHandler(sys.stdout)
+++ /dev/null
-from sqlalchemy.ext.selectresults import SelectResultsExt
-from sqlalchemy.orm.mapper import global_extensions
-
-def install_plugin():
- global_extensions.append(SelectResultsExt)
-
-install_plugin()
See the SQLAlchemy object relational tutorial and mapper configuration
documentation for an overview of how this module is used.
+
"""
-from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, _mapper_registry
-from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, EXT_STOP, EXT_PASS, ExtensionOption, PropComparator
-from sqlalchemy.orm.properties import SynonymProperty, ComparableProperty, PropertyLoader, ColumnProperty, CompositeProperty, BackRef
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import exc
+from sqlalchemy.orm.mapper import \
+ Mapper, _mapper_registry, class_mapper, object_mapper
+from sqlalchemy.orm.interfaces import \
+ EXT_CONTINUE, EXT_STOP, ExtensionOption, InstrumentationManager, \
+ MapperExtension, PropComparator, SessionExtension
+from sqlalchemy.orm.properties import \
+ BackRef, ColumnProperty, ComparableProperty, CompositeProperty, \
+ PropertyLoader, SynonymProperty
from sqlalchemy.orm import mapper as mapperlib
from sqlalchemy.orm import strategies
-from sqlalchemy.orm.query import Query, aliased
-from sqlalchemy.orm.util import polymorphic_union, create_row_adapter
+from sqlalchemy.orm.query import AliasOption, Query
+from sqlalchemy.orm.util import \
+ AliasedClass as aliased, join, outerjoin, polymorphic_union, with_parent
+from sqlalchemy.sql import util as sql_util
from sqlalchemy.orm.session import Session as _Session
from sqlalchemy.orm.session import object_session, sessionmaker
from sqlalchemy.orm.scoping import ScopedSession
-
-
-__all__ = [ 'relation', 'column_property', 'composite', 'backref', 'eagerload',
- 'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer',
- 'undefer', 'undefer_group', 'extension', 'mapper', 'clear_mappers',
- 'compile_mappers', 'class_mapper', 'object_mapper', 'sessionmaker',
- 'scoped_session', 'dynamic_loader', 'MapperExtension',
- 'polymorphic_union', 'comparable_property',
- 'create_session', 'synonym', 'contains_alias', 'Query', 'aliased',
- 'contains_eager', 'EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS',
- 'object_session', 'PropComparator' ]
+from sqlalchemy import util as sa_util
+
+__all__ = (
+ 'EXT_CONTINUE',
+ 'EXT_STOP',
+ 'InstrumentationManager',
+ 'MapperExtension',
+ 'PropComparator',
+ 'Query',
+ 'aliased',
+ 'backref',
+ 'class_mapper',
+ 'clear_mappers',
+ 'column_property',
+ 'comparable_property',
+ 'compile_mappers',
+ 'composite',
+ 'contains_alias',
+ 'contains_eager',
+ 'create_session',
+ 'defer',
+ 'deferred',
+ 'dynamic_loader',
+ 'eagerload',
+ 'eagerload_all',
+ 'extension',
+ 'lazyload',
+ 'mapper',
+ 'noload',
+ 'object_mapper',
+ 'object_session',
+ 'polymorphic_union',
+ 'relation',
+ 'scoped_session',
+ 'sessionmaker',
+ 'synonym',
+ 'undefer',
+ 'undefer_group',
+ )
def scoped_session(session_factory, scopefunc=None):
- """Provides thread-local management of Sessions.
-
- This is a front-end function to the [sqlalchemy.orm.scoping#ScopedSession]
- class.
+ """Provides thread-local management of Sessions.
- Usage::
+ This is a front-end function to the [sqlalchemy.orm.scoping#ScopedSession]
+ class.
- Session = scoped_session(sessionmaker(autoflush=True))
+ Usage::
- To instantiate a Session object which is part of the scoped
- context, instantiate normally::
+ Session = scoped_session(sessionmaker(autoflush=True))
- session = Session()
+ To instantiate a Session object which is part of the scoped context,
+ instantiate normally::
- Most session methods are available as classmethods from
- the scoped session::
+ session = Session()
- Session.commit()
- Session.close()
+ Most session methods are available as classmethods from the scoped
+ session::
- To map classes so that new instances are saved in the current
- Session automatically, as well as to provide session-aware
- class attributes such as "query", use the `mapper` classmethod
- from the scoped session::
+ Session.commit()
+ Session.close()
- mapper = Session.mapper
- mapper(Class, table, ...)
+ To map classes so that new instances are saved in the current Session
+ automatically, as well as to provide session-aware class attributes such
+ as "query", use the `mapper` classmethod from the scoped session::
- """
+ mapper = Session.mapper
+ mapper(Class, table, ...)
- return ScopedSession(session_factory, scopefunc=scopefunc)
+ """
+ return ScopedSession(session_factory, scopefunc=scopefunc)
def create_session(bind=None, **kwargs):
"""create a new [sqlalchemy.orm.session#Session].
It is recommended to use the [sqlalchemy.orm#sessionmaker()] function
instead of create_session().
"""
+
+ if 'transactional' in kwargs:
+ sa_util.warn_deprecated(
+ "The 'transactional' argument to sessionmaker() is deprecated; "
+ "use autocommit=True|False instead.")
+ if 'autocommit' in kwargs:
+ raise TypeError('Specify autocommit *or* transactional, not both.')
+ kwargs['autocommit'] = not kwargs.pop('transactional')
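+        # e.g. transactional=True becomes autocommit=False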
+
kwargs.setdefault('autoflush', False)
- kwargs.setdefault('transactional', False)
+ kwargs.setdefault('autocommit', True)
+ kwargs.setdefault('autoexpire', False)
return _Session(bind=bind, **kwargs)
def relation(argument, secondary=None, **kwargs):
"""Provide a relationship of a primary Mapper to a secondary Mapper.
- This corresponds to a parent-child or associative table relationship.
- The constructed class is an instance of [sqlalchemy.orm.properties#PropertyLoader].
+ This corresponds to a parent-child or associative table relationship. The
+ constructed class is an instance of
+ [sqlalchemy.orm.properties#PropertyLoader].
argument
a class or Mapper instance, representing the target of the relation.
secondary
for a many-to-many relationship, specifies the intermediary table. The
- ``secondary`` keyword argument should generally only be used for a table
- that is not otherwise expressed in any class mapping. In particular,
- using the Association Object Pattern is
- generally mutually exclusive against using the ``secondary`` keyword
- argument.
+ ``secondary`` keyword argument should generally only be used for a
+ table that is not otherwise expressed in any class mapping. In
+ particular, using the Association Object Pattern is generally mutually
+ exclusive against using the ``secondary`` keyword argument.
\**kwargs follow:
which will identify the class/mapper combination to be used
with a particular row. Requires the ``polymorphic_identity``
value to be set for all mappers in the inheritance
- hierarchy. The column specified by ``polymorphic_on`` is
- usually a column that resides directly within the base
+ hierarchy. The column specified by ``polymorphic_on`` is
+ usually a column that resides directly within the base
mapper's mapped table; alternatively, it may be a column
that is only present within the <selectable> portion
of the ``with_polymorphic`` argument.
to be used against this mapper's selectable unit. This is
normally simply the primary key of the `local_table`, but
can be overridden here.
-
+
with_polymorphic
A tuple in the form ``(<classes>, <selectable>)`` indicating the
default style of "polymorphic" loading, that is, which tables
which load from a "concrete" inheriting table, the <selectable>
argument is required, since it usually requires more complex
UNION queries.
-
+
select_table
- Deprecated. Synonymous with
+ Deprecated. Synonymous with
``with_polymorphic=('*', <selectable>)``.
version_id_col
return ExtensionOption(ext)
-def eagerload(name, mapper=None):
+def eagerload(*keys):
"""Return a ``MapperOption`` that will convert the property of the given name into an eager load.
Used with ``query.options()``.
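+
+    e.g., a sketch assuming a mapped ``orders`` relation::
+
+        query.options(eagerload('orders'))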
"""
- return strategies.EagerLazyOption(name, lazy=False, mapper=mapper)
+ return strategies.EagerLazyOption(keys, lazy=False)
+eagerload = sa_util.array_as_starargs_fn_decorator(eagerload)
-def eagerload_all(name, mapper=None):
+def eagerload_all(*keys):
"""Return a ``MapperOption`` that will convert all properties along the given dot-separated path into an eager load.
For example, this::
Used with ``query.options()``.
"""
- return strategies.EagerLazyOption(name, lazy=False, chained=True, mapper=mapper)
+ return strategies.EagerLazyOption(keys, lazy=False, chained=True)
+eagerload_all = sa_util.array_as_starargs_fn_decorator(eagerload_all)
-def lazyload(name, mapper=None):
+def lazyload(*keys):
"""Return a ``MapperOption`` that will convert the property of the
given name into a lazy load.
Used with ``query.options()``.
"""
- return strategies.EagerLazyOption(name, lazy=True, mapper=mapper)
+ return strategies.EagerLazyOption(keys, lazy=True)
+lazyload = sa_util.array_as_starargs_fn_decorator(lazyload)
-def noload(name):
+def noload(*keys):
"""Return a ``MapperOption`` that will convert the property of the
given name into a non-load.
Used with ``query.options()``.
"""
- return strategies.EagerLazyOption(name, lazy=None)
+ return strategies.EagerLazyOption(keys, lazy=None)
def contains_alias(alias):
"""Return a ``MapperOption`` that will indicate to the query that
alias.
"""
- class AliasedRow(MapperExtension):
- def __init__(self, alias):
- self.alias = alias
- if isinstance(self.alias, basestring):
- self.translator = None
- else:
- self.translator = create_row_adapter(alias)
-
- def translate_row(self, mapper, context, row):
- if not self.translator:
- self.translator = create_row_adapter(mapper.mapped_table.alias(self.alias))
- return self.translator(row)
-
- return ExtensionOption(AliasedRow(alias))
+ return AliasOption(alias)
-def contains_eager(key, alias=None, decorator=None):
+def contains_eager(*keys, **kwargs):
"""Return a ``MapperOption`` that will indicate to the query that
the given attribute will be eagerly loaded.
`alias` is the string name of an alias, **or** an ``sql.Alias``
object, which represents the aliased columns in the query. This
argument is optional.
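+
+    A usage sketch, assuming a mapped ``addresses`` relation and an
+    alias of the addresses table::
+
+        adalias = addresses_table.alias()
+        query = query.options(contains_eager('addresses', alias=adalias))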
-
- `decorator` is mutually exclusive of `alias` and is a
- row-processing function which will be applied to the incoming row
- before sending to the eager load handler. use this for more
- sophisticated row adjustments beyond a straight alias.
"""
+ alias = kwargs.pop('alias', None)
+ if kwargs:
+        raise sa_exc.ArgumentError("Invalid kwargs for contains_eager: %r" % kwargs.keys())
+
+ return (strategies.EagerLazyOption(keys, lazy=False), strategies.LoadEagerFromAliasOption(keys, alias=alias))
+contains_eager = sa_util.array_as_starargs_fn_decorator(contains_eager)
- return (strategies.EagerLazyOption(key, lazy=False), strategies.RowDecorateOption(key, alias=alias, decorator=decorator))
-
-def defer(name):
+def defer(*keys):
"""Return a ``MapperOption`` that will convert the column property
of the given name into a deferred load.
Used with ``query.options()``"""
- return strategies.DeferredOption(name, defer=True)
+ return strategies.DeferredOption(keys, defer=True)
+defer = sa_util.array_as_starargs_fn_decorator(defer)
-def undefer(name):
+def undefer(*keys):
"""Return a ``MapperOption`` that will convert the column property
of the given name into a non-deferred (regular column) load.
Used with ``query.options()``.
"""
- return strategies.DeferredOption(name, defer=False)
+ return strategies.DeferredOption(keys, defer=False)
+undefer = sa_util.array_as_starargs_fn_decorator(undefer)
def undefer_group(name):
"""Return a ``MapperOption`` that will convert the given
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import operator, weakref
-from itertools import chain
-import UserDict
+import operator
+import weakref
+
from sqlalchemy import util
+from sqlalchemy.util import attrgetter, itemgetter, EMPTY_SET
from sqlalchemy.orm import interfaces, collections
-from sqlalchemy.orm.util import identity_equal
-from sqlalchemy import exceptions
+import sqlalchemy.exc as sa_exc
+
+# lazy imports
+_entity_info = None
+identity_equal = None
PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT')
ATTR_WAS_SET = util.symbol('ATTR_WAS_SET')
NO_VALUE = util.symbol('NO_VALUE')
NEVER_SET = util.symbol('NEVER_SET')
+NO_ENTITY_NAME = util.symbol('NO_ENTITY_NAME')
-class InstrumentedAttribute(interfaces.PropComparator):
- """public-facing instrumented attribute, placed in the
- class dictionary.
+INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__'
+"""Attribute, elects custom instrumentation when present on a mapped class.
- """
+Allows a class to specify a slightly or wildly different technique for
+tracking changes made to mapped attributes and collections.
+
+Only one instrumentation implementation is allowed in a given object
+inheritance hierarchy.
+
+The value of this attribute must be a callable and will be passed a class
+object. The callable must return one of:
+
+ - An instance of an interfaces.InstrumentationManager or subclass
+ - An object implementing all or some of InstrumentationManager (todo)
+ - A dictionary of callables, implementing all or some of the above (todo)
+ - An instance of a ClassManager or subclass
+
+interfaces.InstrumentationManager is public API and will remain stable
+between releases. ClassManager is not public and no guarantees are made
+about stability. Caveat emptor.
+
+This attribute is consulted by the default SQLAlchemy instrumentation
+resolution code. If custom finders are installed in the global
+instrumentation_finders list, they may or may not choose to honor this
+attribute.
+
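+For example, a minimal sketch, assuming ``MyManager`` is an
+``interfaces.InstrumentationManager`` subclass::
+
+    class MyClass(object):
+        __sa_instrumentation_manager__ = MyManager
+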
+"""
- def __init__(self, impl, comparator=None):
+instrumentation_finders = []
+"""An extensible sequence of instrumentation implementation finding callables.
+
+Finder callables will be passed a class object. If None is returned, the
+next finder in the sequence is consulted. Otherwise the return must be an
+instrumentation factory that follows the same guidelines as
+INSTRUMENTATION_MANAGER.
+
+By default, the only finder is find_native_user_instrumentation_hook, which
+searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
+ClassManager instrumentation is used.
+
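+A sketch of a custom finder (``MyManagerFactory`` is hypothetical)::
+
+    def find_my_instrumentation(cls):
+        if hasattr(cls, '__my_instrumentation__'):
+            return MyManagerFactory
+        return None
+
+    instrumentation_finders.insert(0, find_my_instrumentation)
+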
+"""
+
+class QueryableAttribute(interfaces.PropComparator):
+
+ def __init__(self, impl, comparator=None, parententity=None):
"""Construct an InstrumentedAttribute.
comparator
a sql.Comparator to which class-level compare/math events will be sent
self.impl = impl
self.comparator = comparator
+ self.parententity = parententity
- def __set__(self, instance, value):
- self.impl.set(instance._state, value, None)
-
- def __delete__(self, instance):
- self.impl.delete(instance._state)
-
- def __get__(self, instance, owner):
- if instance is None:
- return self
- return self.impl.get(instance._state)
+ if parententity:
+ mapper, selectable, is_aliased_class = _entity_info(parententity, compile=False)
+ self.property = mapper._get_property(self.impl.key)
+ else:
+ self.property = None
def get_history(self, instance, **kwargs):
- return self.impl.get_history(instance._state, **kwargs)
-
- def clause_element(self):
- return self.comparator.clause_element()
-
- def expression_element(self):
- return self.comparator.expression_element()
-
+ return self.impl.get_history(instance_state(instance), **kwargs)
+
+ def __selectable__(self):
+ # TODO: conditionally attach this method based on clause_element ?
+ return self
+
+ def __clause_element__(self):
+ return self.comparator.__clause_element__()
+
+ def label(self, name):
+ return self.__clause_element__().label(name)
+
def operate(self, op, *other, **kwargs):
return op(self.comparator, *other, **kwargs)
def reverse_operate(self, op, other, **kwargs):
return op(other, self.comparator, **kwargs)
- def hasparent(self, instance, optimistic=False):
- return self.impl.hasparent(instance._state, optimistic=optimistic)
-
- def _property(self):
- from sqlalchemy.orm.mapper import class_mapper
- return class_mapper(self.impl.class_).get_property(self.impl.key)
- property = property(_property, doc="the MapperProperty object associated with this attribute")
-
-class ProxiedAttribute(InstrumentedAttribute):
- """Adds InstrumentedAttribute class-level behavior to a regular descriptor.
-
- Obsoleted by proxied_attribute_factory.
- """
+ def hasparent(self, state, optimistic=False):
+ return self.impl.hasparent(state, optimistic=optimistic)
- class ProxyImpl(object):
- accepts_scalar_loader = False
+ def __str__(self):
+ return repr(self.parententity) + "." + self.property.key
- def __init__(self, key):
- self.key = key
+class InstrumentedAttribute(QueryableAttribute):
+ """Public-facing descriptor, placed in the mapped class dictionary."""
- def __init__(self, key, user_prop, comparator=None):
- self.user_prop = user_prop
- self._comparator = comparator
- self.key = key
- self.impl = ProxiedAttribute.ProxyImpl(key)
+ def __set__(self, instance, value):
+ self.impl.set(instance_state(instance), value, None)
- def comparator(self):
- if callable(self._comparator):
- self._comparator = self._comparator()
- return self._comparator
- comparator = property(comparator)
+ def __delete__(self, instance):
+ self.impl.delete(instance_state(instance))
def __get__(self, instance, owner):
if instance is None:
- self.user_prop.__get__(instance, owner)
return self
- return self.user_prop.__get__(instance, owner)
-
- def __set__(self, instance, value):
- return self.user_prop.__set__(instance, value)
-
- def __delete__(self, instance):
- return self.user_prop.__delete__(instance)
+ return self.impl.get(instance_state(instance))
def proxied_attribute_factory(descriptor):
"""Create an InstrumentedAttribute / user descriptor hybrid.
class ProxyImpl(object):
accepts_scalar_loader = False
+
def __init__(self, key):
self.key = key
-
+
class Proxy(InstrumentedAttribute):
"""A combination of InsturmentedAttribute and a regular descriptor."""
- def __init__(self, key, descriptor, comparator):
+ def __init__(self, key, descriptor, comparator, parententity):
self.key = key
-        # maintain ProxiedAttribute.user_prop compatability.
+        # maintain ProxiedAttribute.user_prop compatibility.
self.descriptor = self.user_prop = descriptor
self._comparator = comparator
+ self._parententity = parententity
self.impl = ProxyImpl(key)
def comparator(self):
def __getattr__(self, attribute):
"""Delegate __getattr__ to the original descriptor."""
return getattr(descriptor, attribute)
+
+ def _property(self):
+ return self._parententity.get_property(self.key, resolve_synonyms=True)
+ property = property(_property)
+
Proxy.__name__ = type(descriptor).__name__ + 'Proxy'
util.monkeypatch_proxied_specials(Proxy, type(descriptor),
class AttributeImpl(object):
"""internal implementation for instrumented attributes."""
- def __init__(self, class_, key, callable_, trackparent=False, extension=None, compare_function=None, **kwargs):
+ def __init__(self, class_, key, callable_, class_manager, trackparent=False, extension=None, compare_function=None, **kwargs):
"""Construct an AttributeImpl.
class_
self.class_ = class_
self.key = key
self.callable_ = callable_
+ self.class_manager = class_manager
self.trackparent = trackparent
if compare_function is None:
self.is_equal = operator.eq
An instance attribute that is loaded by a callable function
will also not have a `hasparent` flag.
- """
+ """
return state.parents.get(id(self), optimistic)
def sethasparent(self, state, value):
"""Set a boolean flag on the given item corresponding to
whether or not it is attached to a parent object via the
attribute represented by this ``InstrumentedAttribute``.
- """
+ """
state.parents[id(self)] = value
def set_callable(self, state, callable_):
The callable overrides the class level callable set in the
``InstrumentedAttribute` constructor.
- """
+ """
if callable_ is None:
self.initialize(state)
else:
if self.key in state.callables:
return state.callables[self.key]
elif self.callable_ is not None:
- return self.callable_(state.obj())
+ return self.callable_(state)
else:
return None
return state.dict[self.key]
except KeyError:
# if no history, check for lazy callables, etc.
- if self.key not in state.committed_state:
+ if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET:
callable_ = self._get_callable(state)
if callable_ is not None:
if passive:
def set_committed_value(self, state, value):
"""set an attribute value on the given instance and 'commit' it."""
- state.commit_attr(self, value)
+ state.commit([self.key])
+
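+        # a directly-set committed value supersedes any pending
+        # loader callable registered for this key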
+ state.callables.pop(self.key, None)
+ state.dict[self.key] = value
+
return value
class ScalarAttributeImpl(AttributeImpl):
"""represents a scalar value-holding InstrumentedAttribute."""
accepts_scalar_loader = True
+ uses_objects = False
def delete(self, state):
- if self.key not in state.committed_state:
- state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE)
+ state.modified_event(self, False, state.dict.get(self.key, NO_VALUE))
# TODO: catch key errors, convert to attributeerror?
- del state.dict[self.key]
- state.modified=True
+ if self.extensions:
+ old = self.get(state)
+ del state.dict[self.key]
+ self.fire_remove_event(state, old, None)
+ else:
+ del state.dict[self.key]
def get_history(self, state, passive=False):
- return _create_history(self, state, state.dict.get(self.key, NO_VALUE))
+ return History.from_attribute(
+ self, state, state.dict.get(self.key, NO_VALUE))
def set(self, state, value, initiator):
if initiator is self:
return
- if self.key not in state.committed_state:
- state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE)
+ state.modified_event(self, False, state.dict.get(self.key, NO_VALUE))
- state.dict[self.key] = value
- state.modified=True
+ if self.extensions:
+ old = self.get(state)
+ state.dict[self.key] = value
+ self.fire_replace_event(state, value, old, initiator)
+ else:
+ state.dict[self.key] = value
+
+ def fire_replace_event(self, state, value, previous, initiator):
+ for ext in self.extensions:
+ ext.set(state, value, previous, initiator or self)
+
+ def fire_remove_event(self, state, value, initiator):
+ for ext in self.extensions:
+ ext.remove(state, value, initiator or self)
def type(self):
self.property.columns[0].type
changes within the value itself.
"""
- def __init__(self, class_, key, callable_, copy_function=None, compare_function=None, **kwargs):
- super(ScalarAttributeImpl, self).__init__(class_, key, callable_, compare_function=compare_function, **kwargs)
- class_._class_state.has_mutable_scalars = True
+ uses_objects = False
+
+ def __init__(self, class_, key, callable_, class_manager, copy_function=None, compare_function=None, **kwargs):
+ super(ScalarAttributeImpl, self).__init__(class_, key, callable_, class_manager, compare_function=compare_function, **kwargs)
+ class_manager.mutable_attributes.add(key)
if copy_function is None:
- raise exceptions.ArgumentError("MutableScalarAttributeImpl requires a copy function")
+ raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function")
self.copy = copy_function
def get_history(self, state, passive=False):
- return _create_history(self, state, state.dict.get(self.key, NO_VALUE))
+ return History.from_attribute(
+ self, state, state.dict.get(self.key, NO_VALUE))
- def commit_to_state(self, state, value):
- state.committed_state[self.key] = self.copy(value)
+ def commit_to_state(self, state, dest):
+ dest[self.key] = self.copy(state.dict[self.key])
def check_mutable_modified(self, state):
(added, unchanged, deleted) = self.get_history(state, passive=True)
- if added or deleted:
- state.modified = True
- return True
- else:
- return False
+ return bool(added or deleted)
def set(self, state, value, initiator):
if initiator is self:
return
- if self.key not in state.committed_state:
- if self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
- else:
- state.committed_state[self.key] = NO_VALUE
+ state.modified_event(self, True, NEVER_SET)
- state.dict[self.key] = value
- state.modified=True
+ if self.extensions:
+ old = self.get(state)
+ state.dict[self.key] = value
+ self.fire_replace_event(state, value, old, initiator)
+ else:
+ state.dict[self.key] = value
class ScalarObjectAttributeImpl(ScalarAttributeImpl):
"""
accepts_scalar_loader = False
+ uses_objects = True
- def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
+ def __init__(self, class_, key, callable_, class_manager, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
super(ScalarObjectAttributeImpl, self).__init__(class_, key,
- callable_, trackparent=trackparent, extension=extension,
+ callable_, class_manager, trackparent=trackparent, extension=extension,
compare_function=compare_function, **kwargs)
if compare_function is None:
self.is_equal = identity_equal
def get_history(self, state, passive=False):
if self.key in state.dict:
- return _create_history(self, state, state.dict[self.key])
+ return History.from_attribute(self, state, state.dict[self.key])
else:
current = self.get(state, passive=passive)
if current is PASSIVE_NORESULT:
return (None, None, None)
else:
- return _create_history(self, state, current)
+ return History.from_attribute(self, state, current)
def set(self, state, value, initiator):
"""Set a value on the given InstanceState.
if initiator is self:
return
-
- if value is not None and not hasattr(value, '_state'):
- raise TypeError("Can not assign %s instance to %s's %r attribute, "
- "a mapped instance was expected." % (
- type(value).__name__, type(state.obj()).__name__, self.key))
-
- # TODO: add options to allow the get() to be passive
+
+ # may want to add options to allow the get() here to be passive
old = self.get(state)
state.dict[self.key] = value
self.fire_replace_event(state, value, old, initiator)
def fire_remove_event(self, state, value, initiator):
- if self.key not in state.committed_state:
- state.committed_state[self.key] = value
- state.modified = True
+ state.modified_event(self, False, value)
if self.trackparent and value is not None:
- self.sethasparent(value._state, False)
+ self.sethasparent(instance_state(value), False)
- instance = state.obj()
for ext in self.extensions:
- ext.remove(instance, value, initiator or self)
+ ext.remove(state, value, initiator or self)
def fire_replace_event(self, state, value, previous, initiator):
- if self.key not in state.committed_state:
- state.committed_state[self.key] = previous
- state.modified = True
+ state.modified_event(self, False, previous)
if self.trackparent:
if value is not None:
- self.sethasparent(value._state, True)
+ self.sethasparent(instance_state(value), True)
if previous is not value and previous is not None:
- self.sethasparent(previous._state, False)
+ self.sethasparent(instance_state(previous), False)
- instance = state.obj()
for ext in self.extensions:
- ext.set(instance, value, previous, initiator or self)
+ ext.set(state, value, previous, initiator or self)
+
class CollectionAttributeImpl(AttributeImpl):
"""A collection-holding attribute that instruments changes in membership.
container object (defaulting to a list) and brokers access to the
CollectionAdapter, a "view" onto that object that presents consistent
bag semantics to the orm layer independent of the user data implementation.
+
"""
accepts_scalar_loader = False
+ uses_objects = True
- def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
+ def __init__(self, class_, key, callable_, class_manager, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
super(CollectionAttributeImpl, self).__init__(class_,
- key, callable_, trackparent=trackparent, extension=extension,
- compare_function=compare_function, **kwargs)
+ key, callable_, class_manager, trackparent=trackparent,
+ extension=extension, compare_function=compare_function, **kwargs)
if copy_function is None:
copy_function = self.__copy
self.copy = copy_function
- if typecallable is None:
- typecallable = list
- self.collection_factory = \
- collections._prepare_instrumentation(typecallable)
+ self.collection_factory = typecallable
# may be removed in 0.5:
self.collection_interface = \
util.duck_type_collection(self.collection_factory())
if current is PASSIVE_NORESULT:
return (None, None, None)
else:
- return _create_history(self, state, current)
+ return History.from_attribute(self, state, current)
def fire_append_event(self, state, value, initiator):
- if self.key not in state.committed_state and self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
-
- state.modified = True
+ state.modified_event(self, True, NEVER_SET, passive=True)
if self.trackparent and value is not None:
- self.sethasparent(value._state, True)
- instance = state.obj()
+ self.sethasparent(instance_state(value), True)
+
for ext in self.extensions:
- ext.append(instance, value, initiator or self)
+ ext.append(state, value, initiator or self)
def fire_pre_remove_event(self, state, initiator):
- if self.key not in state.committed_state and self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
+ state.modified_event(self, True, NEVER_SET, passive=True)
def fire_remove_event(self, state, value, initiator):
- if self.key not in state.committed_state and self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
-
- state.modified = True
+ state.modified_event(self, True, NEVER_SET, passive=True)
if self.trackparent and value is not None:
- self.sethasparent(value._state, False)
+ self.sethasparent(instance_state(value), False)
- instance = state.obj()
for ext in self.extensions:
- ext.remove(instance, value, initiator or self)
+ ext.remove(state, value, initiator or self)
def delete(self, state):
if self.key not in state.dict:
return
- state.modified = True
+ state.modified_event(self, True, NEVER_SET)
collection = self.get_collection(state)
collection.clear_with_event()
def initialize(self, state):
"""Initialize this attribute on the given object instance with an empty collection."""
- _, user_data = self._build_collection(state)
+ _, user_data = self._initialize_collection(state)
state.dict[self.key] = user_data
return user_data
+ def _initialize_collection(self, state):
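+ # container construction is delegated to the ClassManager so that
+ # custom instrumentation can supply its own collection containers
+ # (see _ClassInstrumentationAdapter.initialize_collection).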
+ return state.manager.initialize_collection(
+ self.key, state, self.collection_factory)
+
def append(self, state, value, initiator, passive=False):
if initiator is self:
return
"""
# pulling a new collection first so that an adaptation exception does
# not trigger a lazy load of the old collection.
- new_collection, user_data = self._build_collection(state)
+ new_collection, user_data = self._initialize_collection(state)
if adapter:
new_values = list(adapter(new_collection, iterable))
else:
if old is iterable:
return
- if self.key not in state.committed_state:
- state.committed_state[self.key] = self.copy(old)
+ state.modified_event(self, True, old)
old_collection = self.get_collection(state, old)
state.dict[self.key] = user_data
- state.modified = True
collections.bulk_replace(new_values, old_collection, new_collection)
old_collection.unlink(old)
def set_committed_value(self, state, value):
- """Set an attribute value on the given instance and 'commit' it.
-
- Loads the existing collection from lazy callables in all cases.
- """
+ """Set an attribute value on the given instance and 'commit' it."""
- collection, user_data = self._build_collection(state)
+ collection, user_data = self._initialize_collection(state)
if value:
for item in value:
state.callables.pop(self.key, None)
state.dict[self.key] = user_data
+ state.commit([self.key])
+
if self.key in state.pending:
- # pending items. commit loaded data, add/remove new data
- state.committed_state[self.key] = list(value or [])
- added = state.pending[self.key].added_items
- removed = state.pending[self.key].deleted_items
+ # pending items exist. issue a modified event,
+ # add/remove new items.
+ state.modified_event(self, True, user_data)
+
+ pending = state.pending.pop(self.key)
+ added = pending.added_items
+ removed = pending.deleted_items
for item in added:
collection.append_without_event(item)
for item in removed:
collection.remove_without_event(item)
- del state.pending[self.key]
- elif self.key in state.committed_state:
- # no pending items. remove committed state if any.
- # (this can occur with an expired attribute)
- del state.committed_state[self.key]
return user_data
- def _build_collection(self, state):
- """build a new, blank collection and return it wrapped in a CollectionAdapter."""
-
- user_data = self.collection_factory()
- collection = collections.CollectionAdapter(self, state, user_data)
- return collection, user_data
-
def get_collection(self, state, user_data=None, passive=False):
"""retrieve the CollectionAdapter associated with the given state.
user_data = self.get(state, passive=passive)
if user_data is PASSIVE_NORESULT:
return user_data
- try:
- return getattr(user_data, '_sa_adapter')
- except AttributeError:
- # TODO: this codepath never occurs, and this
- # except/initialize should be removed
- collections.CollectionAdapter(self, state, user_data)
- return getattr(user_data, '_sa_adapter')
+
+ return getattr(user_data, '_sa_adapter')
class GenericBackrefExtension(interfaces.AttributeExtension):
"""An extension which synchronizes a two-way relationship.
def __init__(self, key):
self.key = key
- def set(self, instance, child, oldchild, initiator):
+ def set(self, state, child, oldchild, initiator):
if oldchild is child:
return
if oldchild is not None:
# With lazy=None, there's no guarantee that the full collection is
# present when updating via a backref.
- impl = getattr(oldchild.__class__, self.key).impl
+ old_state = instance_state(oldchild)
+ impl = old_state.get_impl(self.key)
try:
- impl.remove(oldchild._state, instance, initiator, passive=True)
+ impl.remove(old_state, state.obj(), initiator, passive=True)
except (ValueError, KeyError, IndexError):
pass
if child is not None:
- getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True)
+ new_state = instance_state(child)
+ new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=True)
- def append(self, instance, child, initiator):
- getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True)
+ def append(self, state, child, initiator):
+ child_state = instance_state(child)
+ child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=True)
- def remove(self, instance, child, initiator):
+ def remove(self, state, child, initiator):
if child is not None:
- getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator, passive=True)
-
-class ClassState(object):
- """tracks state information at the class level."""
- def __init__(self):
- self.mappers = {}
- self.attrs = {}
- self.has_mutable_scalars = False
+ child_state = instance_state(child)
+ child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=True)
-import sets
-_empty_set = sets.ImmutableSet()
class InstanceState(object):
"""tracks state information at the instance level."""
- def __init__(self, obj):
+ _cleanup = None
+ session_id = None
+ key = None
+ runid = None
+ entity_name = NO_ENTITY_NAME
+ expired_attributes = EMPTY_SET
+
+ def __init__(self, obj, manager):
self.class_ = obj.__class__
- self.obj = weakref.ref(obj, self.__cleanup)
+ self.manager = manager
+ self.obj = weakref.ref(obj, self._cleanup)
self.dict = obj.__dict__
self.committed_state = {}
self.modified = False
self.callables = {}
self.parents = {}
self.pending = {}
- self.appenders = {}
- self.instance_dict = None
- self.runid = None
- self.expired_attributes = _empty_set
-
- def __cleanup(self, ref):
- # tiptoe around Python GC unpredictableness
- instance_dict = self.instance_dict
- if instance_dict is None:
- return
+ self.expired = False
+
+ def dispose(self):
+ del self.session_id
+
+ def check_modified(self):
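+ # "modified" is flipped by modified_event(); mutable scalars can
+ # change in place without firing an event, so they are additionally
+ # compared against their committed copies here.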
+ if self.modified:
+ return True
+ else:
+ for key in self.manager.mutable_attributes:
+ if self.manager[key].impl.check_mutable_modified(self):
+ return True
+ else:
+ return False
- instance_dict = instance_dict()
- if instance_dict is None or instance_dict._mutex is None:
- return
+ def initialize_instance(*mixed, **kwargs):
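+ # *mixed is (self, instance, *constructor_args), packed this way so
+ # the original __init__ call can be forwarded unchanged below.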
+ self, instance, args = mixed[0], mixed[1], mixed[2:]
+ manager = self.manager
- # the mutexing here is based on the assumption that gc.collect()
- # may be firing off cleanup handlers in a different thread than that
- # which is normally operating upon the instance dict.
- instance_dict._mutex.acquire()
+ for fn in manager.events.on_init:
+ fn(self, instance, args, kwargs)
try:
- try:
- self.__resurrect(instance_dict)
- except:
- # catch app cleanup exceptions. no other way around this
- # without warnings being produced
- pass
- finally:
- instance_dict._mutex.release()
+ return manager.events.original_init(*mixed[1:], **kwargs)
+ except:
+ for fn in manager.events.on_init_failure:
+ fn(self, instance, args, kwargs)
+ raise
- def _check_resurrect(self, instance_dict):
- instance_dict._mutex.acquire()
- try:
- return self.obj() or self.__resurrect(instance_dict)
- finally:
- instance_dict._mutex.release()
+ def get_history(self, key, **kwargs):
+ return self.manager.get_impl(key).get_history(self, **kwargs)
+
+ def get_impl(self, key):
+ return self.manager.get_impl(key)
+
+ def get_inst(self, key):
+ return self.manager.get_inst(key)
def get_pending(self, key):
if key not in self.pending:
self.pending[key] = PendingCollection()
return self.pending[key]
- def is_modified(self):
- if self.modified:
- return True
- elif self.class_._class_state.has_mutable_scalars:
- for attr in _managed_attributes(self.class_):
- if hasattr(attr.impl, 'check_mutable_modified') and attr.impl.check_mutable_modified(self):
- return True
- else:
- return False
- else:
- return False
+ def value_as_iterable(self, key, passive=False):
+ """return an InstanceState attribute as a list,
+ regardless of it being a scalar or collection-based
+ attribute.
- def __resurrect(self, instance_dict):
- if self.is_modified():
- # store strong ref'ed version of the object; will revert
- # to weakref when changes are persisted
- obj = new_instance(self.class_, state=self)
- self.obj = weakref.ref(obj, self.__cleanup)
- self._strong_obj = obj
- obj.__dict__.update(self.dict)
- self.dict = obj.__dict__
- return obj
- else:
- del instance_dict[self.dict['_instance_key']]
+ returns None if passive=True and the getter returns
+ PASSIVE_NORESULT.
+ """
+
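+ # scalars come back wrapped as [value]; collections are returned via
+ # their CollectionAdapter so iteration is uniform for the caller.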
+ impl = self.get_impl(key)
+ x = impl.get(self, passive=passive)
+ if x is PASSIVE_NORESULT:
return None
+ elif hasattr(impl, 'get_collection'):
+ return impl.get_collection(self, x, passive=passive)
+ elif isinstance(x, list):
+ return x
+ else:
+ return [x]
+
+ def _run_on_load(self, instance=None):
+ if instance is None:
+ instance = self.obj()
+ self.manager.events.run('on_load', instance)
def __getstate__(self):
- return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':self.expired_attributes, 'callables':self.callables}
+ return {'key': self.key,
+ 'entity_name': self.entity_name,
+ 'committed_state': self.committed_state,
+ 'pending': self.pending,
+ 'parents': self.parents,
+ 'modified': self.modified,
+ 'expired': self.expired,
+ 'instance': self.obj(),
+ 'expired_attributes': self.expired_attributes,
+ 'callables': self.callables}
def __setstate__(self, state):
self.committed_state = state['committed_state']
self.parents = state['parents']
+ self.key = state['key']
+ self.session_id = None
+ self.entity_name = state['entity_name']
self.pending = state['pending']
self.modified = state['modified']
self.obj = weakref.ref(state['instance'])
self.class_ = self.obj().__class__
+ self.manager = manager_of_class(self.class_)
self.dict = self.obj().__dict__
self.callables = state['callables']
self.runid = None
- self.appenders = {}
+ self.expired = state['expired']
self.expired_attributes = state['expired_attributes']
def initialize(self, key):
- getattr(self.class_, key).impl.initialize(self)
+ self.manager.get_impl(key).initialize(self)
def set_callable(self, key, callable_):
self.dict.pop(key, None)
"""__call__ allows the InstanceState to act as a deferred
callable for loading expired attributes, which is also
serializable.
+
"""
- instance = self.obj()
unmodified = self.unmodified
- self.class_._class_state.deferred_scalar_loader(instance, [
- attr.impl.key for attr in _managed_attributes(self.class_) if
+ class_manager = self.manager
+ class_manager.deferred_scalar_loader(self, [
+ attr.impl.key for attr in class_manager.attributes if
attr.impl.accepts_scalar_loader and
attr.impl.key in self.expired_attributes and
attr.impl.key in unmodified
])
for k in self.expired_attributes:
self.callables.pop(k, None)
- self.expired_attributes.clear()
+ del self.expired_attributes
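+ # deleting the instance-level set uncovers the class-level
+ # EMPTY_SET default.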
return ATTR_WAS_SET
def unmodified(self):
"""a set of keys which have no uncommitted changes"""
return util.Set([
- attr.impl.key for attr in _managed_attributes(self.class_) if
- attr.impl.key not in self.committed_state
- and (not hasattr(attr.impl, 'commit_to_state') or not attr.impl.check_mutable_modified(self))
- ])
+ key for key in self.manager.keys() if
+ key not in self.committed_state
+ or (key in self.manager.mutable_attributes and not self.manager[key].impl.check_mutable_modified(self))
+ ])
unmodified = property(unmodified)
def expire_attributes(self, attribute_names):
self.expired_attributes = util.Set(self.expired_attributes)
if attribute_names is None:
- for attr in _managed_attributes(self.class_):
- self.dict.pop(attr.impl.key, None)
- self.expired_attributes.add(attr.impl.key)
- if attr.impl.accepts_scalar_loader:
- self.callables[attr.impl.key] = self
-
- self.committed_state = {}
- else:
- for key in attribute_names:
- self.dict.pop(key, None)
- self.committed_state.pop(key, None)
- self.expired_attributes.add(key)
- if getattr(self.class_, key).impl.accepts_scalar_loader:
- self.callables[key] = self
+ attribute_names = self.manager.keys()
+ self.expired = True
+ self.modified = False
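+ # each expired key gets this InstanceState installed as its loader
+ # callable; __call__ above then performs the deferred load.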
+ for key in attribute_names:
+ self.dict.pop(key, None)
+ self.committed_state.pop(key, None)
+ self.expired_attributes.add(key)
+ if self.manager.get_impl(key).accepts_scalar_loader:
+ self.callables[key] = self
def reset(self, key):
"""remove the given attribute and any callables associated with it."""
+
self.dict.pop(key, None)
self.callables.pop(key, None)
-
- def commit_attr(self, attr, value):
- """set the value of an attribute and mark it 'committed'."""
-
- if hasattr(attr, 'commit_to_state'):
- attr.commit_to_state(self, value)
- else:
- self.committed_state.pop(attr.key, None)
- self.dict[attr.key] = value
- self.pending.pop(attr.key, None)
- self.appenders.pop(attr.key, None)
-
- # we have a value so we can also unexpire it
- self.callables.pop(attr.key, None)
- if attr.key in self.expired_attributes:
- self.expired_attributes.remove(attr.key)
-
+
+ def modified_event(self, attr, should_copy, previous, passive=False):
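+ # record the pre-change value in committed_state, only once until
+ # the attribute is next committed. with passive=True the previous
+ # value is read only from the dict (never loaded via callables);
+ # NEVER_SET means "fetch it here", and mutable values are copied
+ # when should_copy is set.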
+ needs_committed = attr.key not in self.committed_state
+
+ if needs_committed:
+ if previous is NEVER_SET:
+ if passive:
+ if attr.key in self.dict:
+ previous = self.dict[attr.key]
+ else:
+ previous = attr.get(self)
+
+ if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
+ previous = attr.copy(previous)
+
+ if needs_committed:
+ self.committed_state[attr.key] = previous
+
+ self.modified = True
+
def commit(self, keys):
"""commit all attributes named in the given list of key names.
if a value was not populated in state.dict.
"""
- if self.class_._class_state.has_mutable_scalars:
- for key in keys:
- attr = getattr(self.class_, key).impl
- if hasattr(attr, 'commit_to_state') and attr.key in self.dict:
- attr.commit_to_state(self, self.dict[attr.key])
- else:
- self.committed_state.pop(attr.key, None)
- self.pending.pop(key, None)
- self.appenders.pop(key, None)
- else:
- for key in keys:
+ class_manager = self.manager
+ for key in keys:
+ if key in self.dict and key in class_manager.mutable_attributes:
+ class_manager[key].impl.commit_to_state(self, self.committed_state)
+ else:
self.committed_state.pop(key, None)
- self.pending.pop(key, None)
- self.appenders.pop(key, None)
+ self.expired = False
# unexpire attributes which have loaded
for key in self.expired_attributes.intersection(keys):
if key in self.dict:
self.expired_attributes.remove(key)
self.callables.pop(key, None)
-
def commit_all(self):
"""commit all attributes unconditionally.
- This is used after a flush() or a regular instance load or refresh operation
- to mark committed all populated attributes.
+ This is used after a flush() or a full load/refresh
+ to remove all pending state from the instance.
+
+ - all attributes are marked as "committed"
+ - the "strong dirty reference" is removed
+ - the "modified" flag is set to False
+ - any "expired" markers/callables are removed.
Attributes marked as "expired" can potentially remain "expired" after this step
if a value was not populated in state.dict.
+
"""
-
self.committed_state = {}
- self.modified = False
- self.pending = {}
- self.appenders = {}
-
+
# unexpire attributes which have loaded
- for key in list(self.expired_attributes):
- if key in self.dict:
- self.expired_attributes.remove(key)
+ if self.expired_attributes:
+ for key in self.expired_attributes.intersection(self.dict):
self.callables.pop(key, None)
+ self.expired_attributes.difference_update(self.dict)
+
+ for key in self.manager.mutable_attributes:
+ if key in self.dict:
+ self.manager[key].impl.commit_to_state(self, self.committed_state)
- if self.class_._class_state.has_mutable_scalars:
- for attr in _managed_attributes(self.class_):
- if hasattr(attr.impl, 'commit_to_state') and attr.impl.key in self.dict:
- attr.impl.commit_to_state(self, self.dict[attr.impl.key])
-
- # remove strong ref
+ self.modified = self.expired = False
self._strong_obj = None
-class WeakInstanceDict(UserDict.UserDict):
- """similar to WeakValueDictionary, but wired towards 'state' objects."""
+class Events(object):
+ def __init__(self):
+ self.original_init = object.__init__
+ self.on_init = ()
+ self.on_init_failure = ()
+ self.on_load = ()
+
+ def run(self, event, *args, **kwargs):
+ for fn in getattr(self, event):
+ fn(*args, **kwargs)
+
+ def add_listener(self, event, listener):
+ # TODO: not thread-safe; is that a problem here?
+ bucket = getattr(self, event)
+ if bucket == ():
+ setattr(self, event, [listener])
+ else:
+ bucket.append(listener)
- def __init__(self, *args, **kw):
- self._wr = weakref.ref(self)
- # RLock because the mutex is used by a cleanup handler, which can be
- # called at any time (including within an already mutexed block)
- self._mutex = util.threading.RLock()
- UserDict.UserDict.__init__(self, *args, **kw)
+ def remove_listener(self, event, listener):
+ bucket = getattr(self, event)
+ bucket.remove(listener)
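+
+ # hypothetical usage; on_init listeners receive the state, the
+ # instance and the constructor args/kwargs (see
+ # InstanceState.initialize_instance):
+ #
+ #   manager.events.add_listener('on_init', my_listener)
+ #   manager.events.run('on_load', instance)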
- def __getitem__(self, key):
- state = self.data[key]
- o = state.obj()
- if o is None:
- o = state._check_resurrect(self)
- if o is None:
- raise KeyError, key
- return o
- def __contains__(self, key):
- try:
- state = self.data[key]
- o = state.obj()
- if o is None:
- o = state._check_resurrect(self)
- except KeyError:
- return False
- return o is not None
+class ClassManager(dict):
+ """tracks state information at the class level."""
- def has_key(self, key):
- return key in self
+ MANAGER_ATTR = '_fooclass_manager'
+ STATE_ATTR = '_foostate'
- def __repr__(self):
- return "<InstanceDict at %s>" % id(self)
+ event_registry_factory = Events
+ instance_state_factory = InstanceState
- def __setitem__(self, key, value):
- if key in self.data:
- self._mutex.acquire()
- try:
- if key in self.data:
- self.data[key].instance_dict = None
- finally:
- self._mutex.release()
- self.data[key] = value._state
- value._state.instance_dict = self._wr
-
- def __delitem__(self, key):
- state = self.data[key]
- state.instance_dict = None
- del self.data[key]
-
- def get(self, key, default=None):
- try:
- state = self.data[key]
- except KeyError:
- return default
+ def __init__(self, class_):
+ self.class_ = class_
+ self.factory = None # where we came from, for inheritance bookkeeping
+ self.info = {}
+ self.mappers = {}
+ self.mutable_attributes = util.Set()
+ self.local_attrs = {}
+ self.originals = {}
+ for base in class_.__mro__[-2:0:-1]: # reverse, skipping 1st and last
+ cls_state = manager_of_class(base)
+ if cls_state:
+ self.update(cls_state)
+ self.registered = False
+ self._instantiable = False
+ self.events = self.event_registry_factory()
+
+ def instantiable(self, boolean):
+ # experiment, probably won't stay in this form
+ assert boolean ^ self._instantiable, (boolean, self._instantiable)
+ if boolean:
+ self.events.original_init = self.class_.__init__
+ new_init = _generate_init(self.class_, self)
+ self.install_member('__init__', new_init)
else:
- o = state.obj()
- if o is None:
- # This should only happen
- return default
- else:
- return o
-
- def items(self):
- L = []
- for key, state in self.data.items():
- o = state.obj()
- if o is not None:
- L.append((key, o))
- return L
-
- def iteritems(self):
- for state in self.data.itervalues():
- value = state.obj()
- if value is not None:
- yield value._instance_key, value
+ self.uninstall_member('__init__')
+ self._instantiable = bool(boolean)
+ instantiable = property(lambda s: s._instantiable, instantiable)
- def iterkeys(self):
- return self.data.iterkeys()
+ def manage(self):
+ """Mark this instance as the manager for its class."""
+ setattr(self.class_, self.MANAGER_ATTR, self)
- def __iter__(self):
- return self.data.iterkeys()
+ def dispose(self):
+ """Dissasociate this instance from its class."""
+ delattr(self.class_, self.MANAGER_ATTR)
- def __len__(self):
- return len(self.values())
+ def manager_getter(self):
+ return attrgetter(self.MANAGER_ATTR)
- def itervalues(self):
- for state in self.data.itervalues():
- instance = state.obj()
- if instance is not None:
- yield instance
+ def instrument_attribute(self, key, inst, propagated=False):
+ if propagated:
+ if key in self.local_attrs:
+ return # don't override local attr with inherited attr
+ else:
+ self.local_attrs[key] = inst
+ self.install_descriptor(key, inst)
+ self[key] = inst
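+ # propagate down the hierarchy: subclasses record the inherited
+ # attribute but keep local_attrs separate, so a local override is
+ # never clobbered by a propagated one.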
+ for cls in self.class_.__subclasses__():
+ manager = manager_of_class(cls)
+ if manager is None:
+ manager = create_manager_for_cls(cls)
+ manager.instrument_attribute(key, inst, True)
+
+ def uninstrument_attribute(self, key, propagated=False):
+ if key not in self:
+ return
+ if propagated:
+ if key in self.local_attrs:
+ return # don't get rid of local attr
+ else:
+ del self.local_attrs[key]
+ self.uninstall_descriptor(key)
+ del self[key]
+ if key in self.mutable_attributes:
+ self.mutable_attributes.remove(key)
+ for cls in self.class_.__subclasses__():
+ manager = manager_of_class(cls)
+ if manager is None:
+ manager = create_manager_for_cls(cls)
+ manager.uninstrument_attribute(key, True)
+
+ def unregister(self):
+ for key in list(self):
+ if key in self.local_attrs:
+ self.uninstrument_attribute(key)
+ self.registered = False
+
+ def install_descriptor(self, key, inst):
+ if key in (self.STATE_ATTR, self.MANAGER_ATTR):
+ raise KeyError("%r: requested attribute name conflicts with "
+ "instrumentation attribute of the same name." % key)
+ setattr(self.class_, key, inst)
+
+ def uninstall_descriptor(self, key):
+ delattr(self.class_, key)
+
+ def install_member(self, key, implementation):
+ if key in (self.STATE_ATTR, self.MANAGER_ATTR):
+ raise KeyError("%r: requested attribute name conflicts with "
+ "instrumentation attribute of the same name." % key)
+ self.originals.setdefault(key, getattr(self.class_, key, None))
+ setattr(self.class_, key, implementation)
+
+ def uninstall_member(self, key):
+ original = self.originals.pop(key, None)
+ if original is not None:
+ setattr(self.class_, key, original)
+
+ def instrument_collection_class(self, key, collection_class):
+ return collections.prepare_instrumentation(collection_class)
+
+ def initialize_collection(self, key, state, factory):
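+ # returns (CollectionAdapter, user_data): user_data is the raw
+ # container produced by factory() (e.g. a bare list) and the
+ # adapter brokers instrumented access to it.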
+ user_data = factory()
+ adapter = collections.CollectionAdapter(
+ self.get_impl(key), state, user_data)
+ return adapter, user_data
+
+ def is_instrumented(self, key, search=False):
+ if search:
+ return key in self
+ else:
+ return key in self.local_attrs
- def values(self):
- L = []
- for state in self.data.values():
- o = state.obj()
- if o is not None:
- L.append(o)
- return L
+ def get_impl(self, key):
+ return self[key].impl
- def popitem(self):
- raise NotImplementedError()
+ get_inst = dict.__getitem__
- def pop(self, key, *args):
- raise NotImplementedError()
+ def attributes(self):
+ return self.itervalues()
+ attributes = property(attributes)
- def setdefault(self, key, default=None):
- raise NotImplementedError()
+ def deferred_scalar_loader(cls, state, keys):
+ """TODO"""
+ deferred_scalar_loader = classmethod(deferred_scalar_loader)
- def update(self, dict=None, **kwargs):
- raise NotImplementedError()
+ ## InstanceState management
- def copy(self):
- raise NotImplementedError()
+ def new_instance(self, state=None):
+ instance = self.class_.__new__(self.class_)
+ self.setup_instance(instance, state)
+ return instance
+
+ def setup_instance(self, instance, with_state=None):
+ """Register an InstanceState with an instance."""
+ if self.has_state(instance):
+ state = self.state_of(instance)
+ if with_state:
+ assert state is with_state
+ return state
+ if with_state is None:
+ with_state = self.instance_state_factory(instance, self)
+ self.install_state(instance, with_state)
+ return with_state
+
+ def install_state(self, instance, state):
+ setattr(instance, self.STATE_ATTR, state)
+
+ def has_state(self, instance):
+ """True if an InstanceState is installed on the instance."""
+ return bool(getattr(instance, self.STATE_ATTR, False))
+
+ def state_of(self, instance):
+ """Retrieve the InstanceState of an instance.
+
+ May raise KeyError or AttributeError if no state is available.
+ """
+ return getattr(instance, self.STATE_ATTR)
- def all_states(self):
- return self.data.values()
+ def state_getter(self):
+ """Return a (instance) -> InstanceState callable.
-class StrongInstanceDict(dict):
- def all_states(self):
- return [o._state for o in self.values()]
+ "state getter" callables should raise either KeyError or
+ AttributeError if no InstanceState could be found for the
+ instance.
+ """
+ return attrgetter(self.STATE_ATTR)
-def _create_history(attr, state, current):
- original = state.committed_state.get(attr.key, NEVER_SET)
+ def _new_state_if_none(self, instance):
+ """Install a default InstanceState if none is present.
- if hasattr(attr, 'get_collection'):
- current = attr.get_collection(state, current)
- if original is NO_VALUE:
- return (list(current), [], [])
- elif original is NEVER_SET:
- return ([], list(current), [])
+ A private convenience method used by the __init__ decorator.
+ """
+ if self.has_state(instance):
+ return False
else:
- collection = util.OrderedIdentitySet(current)
- s = util.OrderedIdentitySet(original)
- return (list(collection.difference(s)), list(collection.intersection(s)), list(s.difference(collection)))
- else:
- if current is NO_VALUE:
- if original not in [None, NEVER_SET, NO_VALUE]:
- deleted = [original]
+ new_state = self.instance_state_factory(instance, self)
+ self.install_state(instance, new_state)
+ return new_state
+
+ def has_parent(self, state, key, optimistic=False):
+ """TODO"""
+ return self.get_impl(key).hasparent(state, optimistic=optimistic)
+
+ def __nonzero__(self):
+ """All ClassManagers are non-zero regardless of attribute state."""
+ return True
+
+ def __repr__(self):
+ return '<%s of %r at %x>' % (
+ self.__class__.__name__, self.class_, id(self))
+
+class _ClassInstrumentationAdapter(ClassManager):
+ """Adapts a user-defined InstrumentationManager to a ClassManager."""
+
+ def __init__(self, class_, override):
+ ClassManager.__init__(self, class_)
+ self._adapted = override
+
+ def manage(self):
+ self._adapted.manage(self.class_, self)
+
+ def dispose(self):
+ self._adapted.dispose(self.class_)
+
+ def manager_getter(self):
+ return self._adapted.manager_getter(self.class_)
+
+ def instrument_attribute(self, key, inst, propagated=False):
+ ClassManager.instrument_attribute(self, key, inst, propagated)
+ if not propagated:
+ self._adapted.instrument_attribute(self.class_, key, inst)
+
+ def install_descriptor(self, key, inst):
+ self._adapted.install_descriptor(self.class_, key, inst)
+
+ def uninstall_descriptor(self, key):
+ self._adapted.uninstall_descriptor(self.class_, key)
+
+ def install_member(self, key, implementation):
+ self._adapted.install_member(self.class_, key, implementation)
+
+ def uninstall_member(self, key):
+ self._adapted.uninstall_member(self.class_, key)
+
+ def instrument_collection_class(self, key, collection_class):
+ return self._adapted.instrument_collection_class(
+ self.class_, key, collection_class)
+
+ def initialize_collection(self, key, state, factory):
+ delegate = getattr(self._adapted, 'initialize_collection', None)
+ if delegate:
+ return delegate(key, state, factory)
+ else:
+ return ClassManager.initialize_collection(self, key, state, factory)
+
+ def setup_instance(self, instance, state=None):
+ self._adapted.initialize_instance_dict(self.class_, instance)
+ state = ClassManager.setup_instance(self, instance, with_state=state)
+ state.dict = self._adapted.get_instance_dict(self.class_, instance)
+ return state
+
+ def install_state(self, instance, state):
+ self._adapted.install_state(self.class_, instance, state)
+
+ def state_of(self, instance):
+ if hasattr(self._adapted, 'state_of'):
+ return self._adapted.state_of(self.class_, instance)
+ else:
+ getter = self._adapted.state_getter(self.class_)
+ return getter(instance)
+
+ def has_state(self, instance):
+ if hasattr(self._adapted, 'has_state'):
+ return self._adapted.has_state(self.class_, instance)
+ else:
+ try:
+ state = self.state_of(instance)
+ return True
+ except (KeyError, AttributeError):
+ return False
+
+ def state_getter(self):
+ return self._adapted.state_getter(self.class_)
+
+
+class History(tuple):
+ # TODO: migrate [] marker for empty slots to ()
+ __slots__ = ()
+
+ added = property(itemgetter(0))
+ unchanged = property(itemgetter(1))
+ deleted = property(itemgetter(2))
+
+ def __new__(cls, added, unchanged, deleted):
+ return tuple.__new__(cls, (added, unchanged, deleted))
+
+ def from_attribute(cls, attribute, state, current):
+ original = state.committed_state.get(attribute.key, NEVER_SET)
+
+ if hasattr(attribute, 'get_collection'):
+ current = attribute.get_collection(state, current)
+ if original is NO_VALUE:
+ return cls(list(current), [], [])
+ elif original is NEVER_SET:
+ return cls([], list(current), [])
else:
- deleted = []
- return ([], [], deleted)
- elif original is NO_VALUE:
- return ([current], [], [])
- elif original is NEVER_SET or attr.is_equal(current, original) is True: # dont let ClauseElement expressions here trip things up
- return ([], [current], [])
+ collection = util.OrderedIdentitySet(current)
+ s = util.OrderedIdentitySet(original)
+ return cls(list(collection.difference(s)),
+ list(collection.intersection(s)),
+ list(s.difference(collection)))
else:
- if original is not None:
- deleted = [original]
+ if current is NO_VALUE:
+ if original not in [None, NEVER_SET, NO_VALUE]:
+ deleted = [original]
+ else:
+ deleted = []
+ return cls([], [], deleted)
+ elif original is NO_VALUE:
+ return cls([current], [], [])
+ elif (original is NEVER_SET or
+ attribute.is_equal(current, original) is True):
+ # don't let ClauseElement expressions trip things up here
+ return cls([], [current], [])
else:
- deleted = []
- return ([current], [], deleted)
+ if original is not None:
+ deleted = [original]
+ else:
+ deleted = []
+ return cls([current], [], deleted)
+ from_attribute = classmethod(from_attribute)
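+
+ # e.g. a scalar rewritten from 'a' to 'b' yields
+ # History(['b'], [], ['a']); an attribute with no pending change
+ # yields History([], [current], []).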
+
class PendingCollection(object):
"""stores items appended and removed from a collection that has not been loaded yet.
When the collection is loaded, the changes present in PendingCollection are applied
to produce the final result.
+
"""
-
def __init__(self):
self.deleted_items = util.IdentitySet()
self.added_items = util.OrderedIdentitySet()
self.added_items.remove(value)
self.deleted_items.add(value)
-def _managed_attributes(class_):
- """return all InstrumentedAttributes associated with the given class_ and its superclasses."""
-
- return chain(*[cl._class_state.attrs.values() for cl in class_.__mro__[:-1] if hasattr(cl, '_class_state')])
def get_history(state, key, **kwargs):
- return getattr(state.class_, key).impl.get_history(state, **kwargs)
+ return state.get_history(key, **kwargs)
-def get_as_list(state, key, passive=False):
- """return an InstanceState attribute as a list,
- regardless of it being a scalar or collection-based
- attribute.
- returns None if passive=True and the getter returns
- PASSIVE_NORESULT.
- """
+def has_parent(cls, obj, key, optimistic=False):
+ """TODO"""
+ manager = manager_of_class(cls)
+ state = instance_state(obj)
+ return manager.has_parent(state, key, optimistic)
- attr = getattr(state.class_, key).impl
- x = attr.get(state, passive=passive)
- if x is PASSIVE_NORESULT:
- return None
- elif hasattr(attr, 'get_collection'):
- return attr.get_collection(state, x, passive=passive)
- elif isinstance(x, list):
- return x
- else:
- return [x]
-
-def has_parent(class_, instance, key, optimistic=False):
- return getattr(class_, key).impl.hasparent(instance._state, optimistic=optimistic)
+def register_class(class_):
+ """TODO"""
+
+ # TODO: what is this function for? why call it rather than create_manager_for_cls?
+
+ manager = manager_of_class(class_)
+ if manager is None:
+ manager = create_manager_for_cls(class_)
+ if not manager.instantiable:
+ manager.instantiable = True
-def _create_prop(class_, key, uselist, callable_, typecallable, useobject, mutable_scalars, impl_class, **kwargs):
- if impl_class:
- return impl_class(class_, key, typecallable, **kwargs)
- elif uselist:
- return CollectionAttributeImpl(class_, key, callable_, typecallable, **kwargs)
- elif useobject:
- return ScalarObjectAttributeImpl(class_, key, callable_,**kwargs)
- elif mutable_scalars:
- return MutableScalarAttributeImpl(class_, key, callable_, **kwargs)
- else:
- return ScalarAttributeImpl(class_, key, callable_, **kwargs)
+def unregister_class(class_):
+ """TODO"""
+ manager = manager_of_class(class_)
+ assert manager
+ assert manager.instantiable
+ manager.instantiable = False
+ manager.unregister()
-def manage(instance):
- """initialize an InstanceState on the given instance."""
+def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, mutable_scalars=False, impl_class=None, **kwargs):
- if not hasattr(instance, '_state'):
- instance._state = InstanceState(instance)
+ manager = manager_of_class(class_)
+ if manager.is_instrumented(key):
+ # this currently only occurs if two primary mappers are made for the
+ # same class. TODO: possibly have InstrumentedAttribute check
+ # "entity_name" when searching for impl. raise an error if two
+ # attrs attached simultaneously otherwise
+ return
-def new_instance(class_, state=None):
- """create a new instance of class_ without its __init__() method being called.
+ if uselist:
+ factory = kwargs.pop('typecallable', None)
+ typecallable = manager.instrument_collection_class(
+ key, factory or list)
+ else:
+ typecallable = kwargs.pop('typecallable', None)
- Also initializes an InstanceState on the new instance.
- """
+ comparator = kwargs.pop('comparator', None)
+ parententity = kwargs.pop('parententity', None)
- s = class_.__new__(class_)
- if state:
- s._state = state
+ if proxy_property:
+ proxy_type = proxied_attribute_factory(proxy_property)
+ descriptor = proxy_type(key, proxy_property, comparator, parententity)
else:
- s._state = InstanceState(s)
- return s
+ descriptor = InstrumentedAttribute(
+ _create_prop(class_, key, uselist, callable_,
+ class_manager=manager,
+ useobject=useobject,
+ typecallable=typecallable,
+ mutable_scalars=mutable_scalars,
+ impl_class=impl_class,
+ **kwargs),
+ comparator=comparator, parententity=parententity)
+
+ manager.instrument_attribute(key, descriptor)
-def _init_class_state(class_):
- if not '_class_state' in class_.__dict__:
- class_._class_state = ClassState()
+def unregister_attribute(class_, key):
+ manager_of_class(class_).uninstrument_attribute(key)
-def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_loader=None):
- _init_class_state(class_)
- class_._class_state.deferred_scalar_loader=deferred_scalar_loader
+def init_collection(state, key):
+ """Initialize a collection attribute and return the collection adapter."""
+ attr = state.get_impl(key)
+ user_data = attr.initialize(state)
+ return attr.get_collection(state, user_data)
- oldinit = None
- doinit = False
+def set_attribute(instance, key, value):
+ state = instance_state(instance)
+ state.get_impl(key).set(state, value, None)
- def init(instance, *args, **kwargs):
- if not hasattr(instance, '_state'):
- instance._state = InstanceState(instance)
+def get_attribute(instance, key):
+ state = instance_state(instance)
+ return state.get_impl(key).get(state)
- if extra_init:
- extra_init(class_, oldinit, instance, args, kwargs)
+def del_attribute(instance, key):
+ state = instance_state(instance)
+ state.get_impl(key).delete(state)
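+
+# e.g., given a hypothetical mapped instance "user", these helpers are
+# the functional equivalents of instrumented attribute access:
+#
+#   set_attribute(user, 'name', 'ed')    # like: user.name = 'ed'
+#   get_attribute(user, 'name')          # like: user.name
+#   del_attribute(user, 'name')          # like: del user.name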
- try:
- if doinit:
- oldinit(instance, *args, **kwargs)
- elif args or kwargs:
- # simulate error message raised by object(), but don't copy
- # the text verbatim
- raise TypeError("default constructor for object() takes no parameters")
- except:
- if on_exception:
- on_exception(class_, oldinit, instance, args, kwargs)
- raise
+def is_instrumented(instance, key):
+ return manager_of_class(instance.__class__).is_instrumented(key, search=True)
+
+class InstrumentationRegistry(object):
+ """Private instrumentation registration singleton."""
+ manager_finders = weakref.WeakKeyDictionary()
+ state_finders = util.WeakIdentityMapping()
+ extended = False
- # override oldinit
- oldinit = class_.__init__
- if oldinit is None or not hasattr(oldinit, '_oldinit'):
- init._oldinit = oldinit
- class_.__init__ = init
- # if oldinit is already one of our 'init' methods, replace it
- elif hasattr(oldinit, '_oldinit'):
- init._oldinit = oldinit._oldinit
- class_.__init = init
- oldinit = oldinit._oldinit
+ def create_manager_for_cls(self, class_):
+ assert class_ is not None
+ assert manager_of_class(class_) is None
- if oldinit is not None:
- doinit = oldinit is not object.__init__
+ for finder in instrumentation_finders:
+ factory = finder(class_)
+ if factory is not None:
+ break
+ else:
+ factory = ClassManager
+
+ existing_factories = collect_management_factories_for(class_)
+ existing_factories.add(factory)
+ if len(existing_factories) > 1:
+ raise TypeError(
+ "multiple instrumentation implementations specified "
+ "in %s inheritance hierarchy: %r" % (
+ class_.__name__, list(existing_factories)))
+
+ manager = factory(class_)
+ if not isinstance(manager, ClassManager):
+ manager = _ClassInstrumentationAdapter(class_, manager)
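+ # the first non-default factory switches the module into "extended"
+ # mode: lookups route through the registry's finder tables instead
+ # of direct attribute access (see _install_lookup_strategy).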
+ if factory != ClassManager and not self.extended:
+ self.extended = True
+ _install_lookup_strategy(self)
+
+ manager.factory = factory
+ manager.manage()
+ self.manager_finders[class_] = manager.manager_getter()
+ self.state_finders[class_] = manager.state_getter()
+ return manager
+
+ def manager_of_class(self, cls):
+ if cls is None:
+ return None
try:
- init.__name__ = oldinit.__name__
- init.__doc__ = oldinit.__doc__
- except:
- # cant set __name__ in py 2.3 !
- pass
+ finder = self.manager_finders[cls]
+ except KeyError:
+ return None
+ else:
+ return finder(cls)
-def unregister_class(class_):
- if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
- if class_.__init__._oldinit is not None:
- class_.__init__ = class_.__init__._oldinit
+ def state_of(self, instance):
+ if instance is None:
+ raise AttributeError("None has no persistent state.")
+ return self.state_finders[instance.__class__](instance)
+
+ def state_or_default(self, instance, default=None):
+ if instance is None:
+ return default
+ try:
+ finder = self.state_finders[instance.__class__]
+ except KeyError:
+ return default
else:
- delattr(class_, '__init__')
+ try:
+ return finder(instance)
+ except (KeyError, AttributeError):
+ return default
+
+ def unregister(self, class_):
+ if class_ in self.manager_finders:
+ manager = self.manager_of_class(class_)
+ manager.dispose()
+ del self.manager_finders[class_]
+ del self.state_finders[class_]
+
+# Create a registry singleton and prepare placeholders for lookup functions.
+
+instrumentation_registry = InstrumentationRegistry()
+create_manager_for_cls = None
+manager_of_class = None
+instance_state = None
+_lookup_strategy = None
+
+def _install_lookup_strategy(implementation):
+ """Switch between native and extended instrumentation modes.
- if '_class_state' in class_.__dict__:
- _class_state = class_.__dict__['_class_state']
- for key, attr in _class_state.attrs.iteritems():
- if key in class_.__dict__:
- delattr(class_, attr.impl.key)
- delattr(class_, '_class_state')
+ Completely private. Use the instrumentation_finders interface to
+ inject global instrumentation behavior.
-def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, mutable_scalars=False, impl_class=None, **kwargs):
- _init_class_state(class_)
+ """
+ global manager_of_class, instance_state, create_manager_for_cls
+ global _lookup_strategy
+
+ # Using a symbol here to make debugging a little friendlier.
+ if implementation is not util.symbol('native'):
+ manager_of_class = implementation.manager_of_class
+ instance_state = implementation.state_of
+ create_manager_for_cls = implementation.create_manager_for_cls
+ else:
+ def manager_of_class(class_):
+ return getattr(class_, ClassManager.MANAGER_ATTR, None)
+ instance_state = attrgetter(ClassManager.STATE_ATTR)
+ create_manager_for_cls = instrumentation_registry.create_manager_for_cls
+ # TODO: maybe log an event when setting a strategy.
+ _lookup_strategy = implementation
- typecallable = kwargs.pop('typecallable', None)
- if isinstance(typecallable, InstrumentedAttribute):
- typecallable = None
- comparator = kwargs.pop('comparator', None)
+_install_lookup_strategy(util.symbol('native'))
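+
+# In native mode, state and manager lookups are plain attribute access
+# on the instance/class; extended mode (enabled once any custom
+# instrumentation factory is seen) routes every lookup through the
+# registry above.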
- if key in class_.__dict__ and isinstance(class_.__dict__[key], InstrumentedAttribute):
- # this currently only occurs if two primary mappers are made for the same class.
- # TODO: possibly have InstrumentedAttribute check "entity_name" when searching for impl.
- # raise an error if two attrs attached simultaneously otherwise
- return
+def find_native_user_instrumentation_hook(cls):
+ """Find user-specified instrumentation management for a class."""
+ return getattr(cls, INSTRUMENTATION_MANAGER, None)
+instrumentation_finders.append(find_native_user_instrumentation_hook)
- if proxy_property:
- proxy_type = proxied_attribute_factory(proxy_property)
- inst = proxy_type(key, proxy_property, comparator)
- else:
- inst = InstrumentedAttribute(_create_prop(class_, key, uselist, callable_, useobject=useobject,
- typecallable=typecallable, mutable_scalars=mutable_scalars, impl_class=impl_class, **kwargs), comparator=comparator)
+def collect_management_factories_for(cls):
+ """Return a collection of factories in play or specified for a hierarchy.
- setattr(class_, key, inst)
- class_._class_state.attrs[key] = inst
+ Traverses the entire inheritance graph of a cls and returns a collection
+ of instrumentation factories for those classes. Factories are extracted
+ from active ClassManagers, if available, otherwise
+ instrumentation_finders is consulted.
-def unregister_attribute(class_, key):
- class_state = class_._class_state
- if key in class_state.attrs:
- del class_._class_state.attrs[key]
- delattr(class_, key)
+ """
+ hierarchy = util.class_hierarchy(cls)
+ factories = util.Set()
+ for member in hierarchy:
+ manager = manager_of_class(member)
+ if manager is not None:
+ factories.add(manager.factory)
+ else:
+ for finder in instrumentation_finders:
+ factory = finder(member)
+ if factory is not None:
+ break
+ else:
+ factory = None
+ factories.add(factory)
+ factories.discard(None)
+ return factories
+
+
+def _create_prop(class_, key, uselist, callable_, class_manager, typecallable, useobject, mutable_scalars, impl_class, **kwargs):
+ if impl_class:
+ return impl_class(class_, key, typecallable, class_manager=class_manager, **kwargs)
+ elif uselist:
+ return CollectionAttributeImpl(class_, key, callable_,
+ typecallable=typecallable,
+ class_manager=class_manager, **kwargs)
+ elif useobject:
+ return ScalarObjectAttributeImpl(class_, key, callable_,
+ class_manager=class_manager, **kwargs)
+ elif mutable_scalars:
+ return MutableScalarAttributeImpl(class_, key, callable_,
+ class_manager=class_manager, **kwargs)
+ else:
+ return ScalarAttributeImpl(class_, key, callable_,
+ class_manager=class_manager, **kwargs)
+
+def _generate_init(class_, class_manager):
+ """Build an __init__ decorator that triggers ClassManager events."""
+
+ original__init__ = class_.__init__
+ assert original__init__
+
+ # Take some care here to preserve the user's __init__ calling
+ # signature.
+ # FIXME: need to juggle local names to avoid constructor argument
+ # clashes.
+ func_body = """\
+def __init__(%(args)s):
+ new_state = class_manager._new_state_if_none(%(self_arg)s)
+ if new_state:
+ return new_state.initialize_instance(%(apply_kw)s)
+ else:
+ return original__init__(%(apply_kw)s)
+"""
+ func_vars = util.format_argspec_init(original__init__, grouped=False)
+ func_text = func_body % func_vars
+ # TODO: log func_text at the debug level
+
+ env = locals().copy()
+ exec func_text in env
+ __init__ = env['__init__']
+ __init__.__doc__ = original__init__.__doc__
+ return __init__
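+
+# For an original "def __init__(self, name, email=None)", the generated
+# wrapper looks roughly like this (a sketch; the exact text comes from
+# util.format_argspec_init):
+#
+#   def __init__(self, name, email=None):
+#       new_state = class_manager._new_state_if_none(self)
+#       if new_state:
+#           return new_state.initialize_instance(self, name, email=email)
+#       else:
+#           return original__init__(self, name, email=email)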
-def init_collection(instance, key):
- """Initialize a collection attribute and return the collection adapter."""
- attr = getattr(instance.__class__, key).impl
- state = instance._state
- user_data = attr.initialize(state)
- return attr.get_collection(state, user_data)
The owning object and InstrumentedCollectionAttribute are also reachable
through the adapter, allowing for some very sophisticated behavior.
+
"""
import copy
import sys
import weakref
-from sqlalchemy import exceptions, schema, util as sautil
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import schema
+import sqlalchemy.util as sautil
from sqlalchemy.util import attrgetter, Set
'mapped_collection', 'column_mapped_collection',
'attribute_mapped_collection']
+__instrumentation_mutex = sautil.threading.Lock()
+
+
def column_mapped_collection(mapping_spec):
"""A dictionary-based collection type with column-based keying.
can not, for example, map on foreign key values if those key values will
change during the session, i.e. from None to a database-assigned integer
after a session flush.
- """
- from sqlalchemy.orm import object_mapper
+ """
+ from sqlalchemy.orm.util import _state_mapper
+ from sqlalchemy.orm.attributes import instance_state
if isinstance(mapping_spec, schema.Column):
def keyfunc(value):
- m = object_mapper(value)
- return m._get_attr_by_column(value, mapping_spec)
+ state = instance_state(value)
+ m = _state_mapper(state)
+ return m._get_state_attr_by_column(state, mapping_spec)
else:
cols = []
for c in mapping_spec:
if not isinstance(c, schema.Column):
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"mapping_spec tuple may only contain columns")
cols.append(c)
mapping_spec = tuple(cols)
def keyfunc(value):
- m = object_mapper(value)
- return tuple([m._get_attr_by_column(value, c) for c in mapping_spec])
+ state = instance_state(value)
+ m = _state_mapper(state)
+ return tuple([m._get_state_attr_by_column(state, c)
+ for c in mapping_spec])
return lambda: MappedCollection(keyfunc)
def attribute_mapped_collection(attr_name):
can not, for example, map on foreign key values if those key values will
change during the session, i.e. from None to a database-assigned integer
after a session flush.
- """
+ """
return lambda: MappedCollection(attrgetter(attr_name))
can not, for example, map on foreign key values if those key values will
change during the session, i.e. from None to a database-assigned integer
after a session flush.
- """
+ """
return lambda: MappedCollection(keyfunc)
class collection(object):
Decorators can be specified in long-hand for Python 2.3, or with
the class-level dict attribute '__instrumentation__'; see the
source for details.
- """
+ """
# Bundled as a class solely for ease of use: packaging, doc strings,
# importability.
If the appender method is internally instrumented, you must also
receive the keyword argument '_sa_initiator' and ensure its
promulgation to collection events.
- """
+ """
setattr(fn, '_sa_instrument_role', 'appender')
return fn
appender = classmethod(appender)
If the remove method is internally instrumented, you must also
receive the keyword argument '_sa_initiator' and ensure its
promulgation to collection events.
- """
+ """
setattr(fn, '_sa_instrument_role', 'remover')
return fn
remover = classmethod(remover)
@collection.iterator
def __iter__(self): ...
- """
+ """
setattr(fn, '_sa_instrument_role', 'iterator')
return fn
iterator = classmethod(iterator)
# never be called, unless:
@collection.internally_instrumented
def extend(self, items): ...
- """
+ """
setattr(fn, '_sa_instrumented', True)
return fn
internally_instrumented = classmethod(internally_instrumented)
invoked immediately after the '_sa_adapter' property is set on
the instance. A single argument is passed: the collection adapter
that has been linked, or None if unlinking.
- """
+ """
setattr(fn, '_sa_instrument_role', 'on_link')
return fn
on_link = classmethod(on_link)
Supply an implementation of this method if you want to expand the
range of possible types that can be assigned in bulk or perform
validation on the values about to be assigned.
- """
+ """
setattr(fn, '_sa_instrument_role', 'converter')
return fn
converter = classmethod(converter)
@collection.adds('entity')
def do_stuff(self, thing, entity=None): ...
- """
+ """
def decorator(fn):
setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
return fn
@collection.replaces(2)
def __setitem__(self, index, item): ...
- """
+ """
def decorator(fn):
setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
setattr(fn, '_sa_instrument_after', 'fire_remove_event')
For methods where the value to remove is not known at call-time, use
collection.removes_return.
- """
+ """
def decorator(fn):
setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg))
return fn
For methods where the value to remove is known at call-time, use
collection.remove.
- """
+ """
def decorator(fn):
setattr(fn, '_sa_instrument_after', 'fire_remove_event')
return fn
# implementations
def collection_adapter(collection):
"""Fetch the CollectionAdapter for a collection."""
-
return getattr(collection, '_sa_adapter', None)
def collection_iter(collection):
If the collection is an ORM collection, it need not be attached to an
object to be iterable.
- """
+ """
try:
return getattr(collection, '_sa_iterator',
getattr(collection, '__iter__'))()
The ORM uses an CollectionAdapter exclusively for interaction with
entity collections.
- """
+ """
def __init__(self, attr, owner_state, data):
self.attr = attr
self._data = weakref.ref(data)
def link_to_self(self, data):
"""Link a collection to this adapter, and fire a link event."""
-
setattr(data, '_sa_adapter', self)
if hasattr(data, '_sa_on_link'):
getattr(data, '_sa_on_link')(self)
def unlink(self, data):
"""Unlink a collection from any adapter, and fire a link event."""
-
setattr(data, '_sa_adapter', None)
if hasattr(data, '_sa_on_link'):
getattr(data, '_sa_on_link')(None)
If a converter implementation is not supplied on the collection,
a default duck-typing-based implementation is used.
- """
+ """
converter = getattr(self._data(), '_sa_converter', None)
if converter is not None:
return converter(obj)
def append_with_event(self, item, initiator=None):
"""Add an entity to the collection, firing mutation events."""
-
getattr(self._data(), '_sa_appender')(item, _sa_initiator=initiator)
def append_without_event(self, item):
"""Add or restore an entity to the collection, firing no events."""
-
getattr(self._data(), '_sa_appender')(item, _sa_initiator=False)
def remove_with_event(self, item, initiator=None):
"""Remove an entity from the collection, firing mutation events."""
-
getattr(self._data(), '_sa_remover')(item, _sa_initiator=initiator)
def remove_without_event(self, item):
"""Remove an entity from the collection, firing no events."""
-
getattr(self._data(), '_sa_remover')(item, _sa_initiator=False)
def clear_with_event(self, initiator=None):
"""Empty the collection, firing a mutation event for each entity."""
-
for item in list(self):
self.remove_with_event(item, initiator)
def clear_without_event(self):
"""Empty the collection, firing no events."""
-
for item in list(self):
self.remove_without_event(item)
def __iter__(self):
"""Iterate over entities in the collection."""
-
return getattr(self._data(), '_sa_iterator')()
def __len__(self):
"""Count entities in the collection."""
-
return len(list(getattr(self._data(), '_sa_iterator')()))
def __nonzero__(self):
Initiator is the InstrumentedAttribute that initiated the membership
mutation, and should be left as None unless you are passing along
an initiator value from a chained operation.
- """
+ """
if initiator is not False and item is not None:
self.attr.fire_append_event(self.owner_state, item, initiator)
Initiator is the InstrumentedAttribute that initiated the membership
mutation, and should be left as None unless you are passing along
an initiator value from a chained operation.
- """
+ """
if initiator is not False and item is not None:
self.attr.fire_remove_event(self.owner_state, item, initiator)
Only called if the entity cannot be removed after calling
fire_remove_event().
- """
+ """
self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator)
def __getstate__(self):
for member in removals:
existing_adapter.remove_with_event(member)
-__instrumentation_mutex = sautil.threading.Lock()
-def _prepare_instrumentation(factory):
+def prepare_instrumentation(factory):
"""Prepare a callable for future use as a collection class factory.
Given a collection class factory (either a type or no-arg callable),
This function is responsible for converting collection_class=list
into the run-time behavior of collection_class=InstrumentedList.
- """
+ """
# Convert a builtin to 'Instrumented*'
if factory in __canned_instrumentation:
factory = __canned_instrumentation[factory]
Given a collection factory that returns a builtin type (e.g. a list),
return a wrapped function that converts that type to one of our
instrumented types.
- """
+ """
def wrapper():
collection = original_factory()
type_ = type(collection)
# collection
return __canned_instrumentation[type_](collection)
else:
- raise exceptions.InvalidRequestError(
+ raise sa_exc.InvalidRequestError(
"Collection class factories must produce instances of a "
"single class.")
try:
def _instrument_class(cls):
"""Modify methods in a class and install instrumentation."""
-
# FIXME: more formally document this as a decoratorless/Python 2.3
# option for specifying instrumentation. (likely doc'd here in code only,
# not in online docs.)
# types is transformed into one of our trivial subclasses
# (e.g. InstrumentedList). Catch anything else that sneaks in here...
if cls.__module__ == '__builtin__':
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"Can not instrument a built-in type. Use a "
"subclass, even a trivial one.")
# ensure all roles are present, and apply implicit instrumentation if
# needed
if 'appender' not in roles or not hasattr(cls, roles['appender']):
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"Type %s must elect an appender method to be "
"a collection class" % cls.__name__)
elif (roles['appender'] not in methods and
methods[roles['appender']] = ('fire_append_event', 1, None)
if 'remover' not in roles or not hasattr(cls, roles['remover']):
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"Type %s must elect a remover method to be "
"a collection class" % cls.__name__)
elif (roles['remover'] not in methods and
methods[roles['remover']] = ('fire_remove_event', 1, None)
if 'iterator' not in roles or not hasattr(cls, roles['iterator']):
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"Type %s must elect an iterator method to be "
"a collection class" % cls.__name__)
def _instrument_membership_mutator(method, before, argument, after):
"""Route method args and/or return value through the collection adapter."""
-
# This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))'
if before:
fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0]))
if before:
if pos_arg is None:
if named_arg not in kw:
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"Missing argument %s" % argument)
value = kw[named_arg]
else:
elif named_arg in kw:
value = kw[named_arg]
else:
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"Missing argument %s" % argument)
initiator = kw.pop('_sa_initiator', None)
def __set(collection, item, _sa_initiator=None):
"""Run set events, may eventually be inlined into decorators."""
-
if _sa_initiator is not False and item is not None:
executor = getattr(collection, '_sa_adapter', None)
if executor:
def __del(collection, item, _sa_initiator=None):
"""Run del events, may eventually be inlined into decorators."""
-
if _sa_initiator is not False and item is not None:
executor = getattr(collection, '_sa_adapter', None)
if executor:
def __before_delete(collection, _sa_initiator=None):
"""Special method to run 'commit existing value' methods"""
-
executor = getattr(collection, '_sa_adapter', None)
if executor:
getattr(executor, 'fire_pre_remove_event')(_sa_initiator)
def _list_decorators():
- """Hand-turned instrumentation wrappers that can decorate any list-like
- class."""
+ """Tailored instrumentation wrappers for any list-like class."""
def _tidy(fn):
setattr(fn, '_sa_instrumented', True)
return l
def _dict_decorators():
- """Hand-turned instrumentation wrappers that can decorate any dict-like
- mapping class."""
+ """Tailored instrumentation wrappers for any dict-like mapping class."""
def _tidy(fn):
setattr(fn, '_sa_instrumented', True)
fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__')
- Unspecified=sautil.symbol('Unspecified')
+ Unspecified = sautil.symbol('Unspecified')
def __setitem__(fn):
def __setitem__(self, key, value, _sa_initiator=None):
def _set_decorators():
- """Hand-turned instrumentation wrappers that can decorate any set-like
- sequence class."""
+ """Tailored instrumentation wrappers for any set-like class."""
def _tidy(fn):
setattr(fn, '_sa_instrumented', True)
fn.__doc__ = getattr(getattr(Set, fn.__name__), '__doc__')
- Unspecified=sautil.symbol('Unspecified')
+ Unspecified = sautil.symbol('Unspecified')
def add(fn):
def add(self, value, _sa_initiator=None):
``set`` and ``remove`` are implemented in terms of a keying function: any
callable that takes an object and returns an object for use as a dictionary
key.
+
"""
def __init__(self, keyfunc):
returns an object for use as a dictionary key.
The keyfunc will be called every time the ORM needs to add a member by
- value-only (such as when loading instances from the database) or remove
- a member. The usual cautions about dictionary keying apply-
+ value-only (such as when loading instances from the database) or
+ remove a member. The usual cautions about dictionary keying apply:
``keyfunc(object)`` should return the same output for the life of the
collection. Keying based on mutable properties can result in
unreachable instances "lost" in the collection.
+
"""
self.keyfunc = keyfunc
def set(self, value, _sa_initiator=None):
- """Add an item to the collection, with a key provided by this instance's keyfunc."""
+ """Add an item by value, consulting the keyfunc for the key."""
key = self.keyfunc(value)
self.__setitem__(key, value, _sa_initiator)
set = collection.appender(set)
def remove(self, value, _sa_initiator=None):
- """Remove an item from the collection by value, consulting this instance's keyfunc for the key."""
+ """Remove an item by value, consulting the keyfunc for the key."""
key = self.keyfunc(value)
# Let self[key] raise if key is not in this collection
# testlib.pragma exempt:__ne__
if self[key] != value:
- raise exceptions.InvalidRequestError(
+ raise sa_exc.InvalidRequestError(
"Can not remove '%s': collection holds '%s' for key '%s'. "
"Possible cause: is the MappedCollection key function "
"based on mutable properties or properties that only obtain "
Raises a TypeError if the key in any (key, value) pair in the dictlike
object does not match the key that this collection's keyfunc would
have assigned for that value.
- """
+ """
for incoming_key, value in sautil.dictlike_iteritems(dictlike):
new_key = self.keyfunc(value)
if incoming_key != new_key:
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Relationship dependencies.
-"""Bridge the ``PropertyLoader`` (i.e. a ``relation()``) and the
+Bridges the ``PropertyLoader`` (i.e. a ``relation()``) and the
``UOWTransaction`` together to allow processing of relation()-based
- dependencies at flush time.
+dependencies at flush time.
+
"""
-from sqlalchemy.orm import sync
-from sqlalchemy import sql, util, exceptions
+from sqlalchemy import sql, util
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy.orm import attributes, exc, sync
from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
MANYTOONE: ManyToOneDP,
MANYTOMANY : ManyToManyDP,
}
- if prop.association is not None:
- return AssociationDP(prop)
- else:
- return types[prop.direction](prop)
+ return types[prop.direction](prop)
class DependencyProcessor(object):
no_dependencies = False
self.parent = prop.parent
self.secondary = prop.secondary
self.direction = prop.direction
- self.is_backref = prop.is_backref
+ self.is_backref = prop._is_backref
self.post_update = prop.post_update
self.foreign_keys = prop.foreign_keys
self.passive_deletes = prop.passive_deletes
self.enable_typechecks = prop.enable_typechecks
self.key = prop.key
if not self.prop.synchronize_pairs:
- raise exceptions.ArgumentError("Can't build a DependencyProcessor for relation %s. No target attributes to populate between parent and child are present" % self.prop)
+ raise sa_exc.ArgumentError("Can't build a DependencyProcessor for relation %s. No target attributes to populate between parent and child are present" % self.prop)
def _get_instrumented_attribute(self):
"""Return the ``InstrumentedAttribute`` handled by this
``DependencyProcessor``.
"""
- return getattr(self.parent.class_, self.key)
+ return self.parent.class_manager.get_impl(self.key)
def hasparent(self, state):
"""return True if the given object instance has a parent,
according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``."""
# TODO: use correct API for this
- return self._get_instrumented_attribute().impl.hasparent(state)
+ return self._get_instrumented_attribute().hasparent(state)
def register_dependencies(self, uowcommit):
"""Tell a ``UOWTransaction`` what mappers are dependent on
if not self.enable_typechecks:
return
if state is not None and not self.mapper._canload(state):
- raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper))
+ raise exc.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper))
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
"""Called during a flush to synchronize primary key identifier
# the child objects have to have their foreign key to the parent set to NULL
# this phase can be called safely for any cascade but is unnecessary if delete cascade
# is on.
- if self.post_update or not self.passive_deletes=='all':
+ if self.post_update or self.passive_deletes != 'all':
for state in deplist:
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if unchanged or deleted:
for child in deleted:
if child is not None and self.hasparent(child) is False:
# head object is being deleted, and we manage its list of child objects
# the child objects have to have their foreign key to the parent set to NULL
if not self.post_update:
- should_null_fks = not self.cascade.delete and not self.passive_deletes=='all'
+ should_null_fks = not self.cascade.delete and self.passive_deletes != 'all'
for state in deplist:
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if unchanged or deleted:
for child in deleted:
if child is not None and self.hasparent(child) is False:
uowcommit.register_object(child)
else:
for state in deplist:
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True)
if added or deleted:
for child in added:
if child is not None:
elif self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
- uowcommit.register_object(c._state, isdelete=True)
+ uowcommit.register_object(
+ attributes.instance_state(c),
+ isdelete=True)
if not self.passive_updates and self._pks_changed(uowcommit, state):
if not unchanged:
(added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=False)
for s in [elem for elem in uowcommit.session.identity_map.all_states()
if issubclass(elem.class_, self.parent.class_) and
self.key in elem.dict and
- elem.dict[self.key]._state in switchers
+ attributes.instance_state(elem.dict[self.key]) in switchers
]:
uowcommit.register_object(s, listonly=self.passive_updates)
- sync.populate(s.dict[self.key]._state, self.mapper, s, self.parent, self.prop.synchronize_pairs)
+ sync.populate(attributes.instance_state(s.dict[self.key]), self.mapper, s, self.parent, self.prop.synchronize_pairs)
#self.syncrules.execute(s.dict[self.key]._state, s, None, None, False)
def _pks_changed(self, uowcommit, state):
def process_dependencies(self, task, deplist, uowcommit, delete = False):
#print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
if delete:
- if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes=='all':
+ if self.post_update and not self.cascade.delete_orphan and self.passive_deletes != 'all':
# post_update means we have to update our row to not reference the child object
# before we can DELETE the row
for state in deplist:
self._synchronize(state, None, None, True, uowcommit)
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if added or unchanged or deleted:
self._conditional_post_update(state, uowcommit, deleted + unchanged + added)
else:
for state in deplist:
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True)
if added or deleted or unchanged:
for child in added:
self._synchronize(state, child, None, False, uowcommit)
if delete:
if self.cascade.delete or self.cascade.delete_orphan:
for state in deplist:
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if self.cascade.delete_orphan:
todelete = added + unchanged + deleted
else:
continue
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
- uowcommit.register_object(c._state, isdelete=True)
+ uowcommit.register_object(
+ attributes.instance_state(c), isdelete=True)
else:
for state in deplist:
uowcommit.register_object(state)
if self.cascade.delete_orphan:
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if deleted:
for child in deleted:
if self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
- uowcommit.register_object(c._state, isdelete=True)
+ uowcommit.register_object(
+ attributes.instance_state(c),
+ isdelete=True)
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
if delete:
for state in deplist:
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if deleted or unchanged:
for child in deleted + unchanged:
if child is None or (reverse_dep and (reverse_dep, "manytomany", child, state) in uowcommit.attributes):
statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow]))
result = connection.execute(statement, secondary_delete)
if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete):
- raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of secondary table rows deleted from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_delete)))
+ raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of secondary table rows deleted from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_delete)))
if secondary_update:
statement = self.secondary.update(sql.and_(*[c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow]))
result = connection.execute(statement, secondary_update)
if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update):
- raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of secondary table rows updated from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_update)))
+ raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of secondary table rows updated from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_update)))
if secondary_insert:
statement = self.secondary.insert()
#print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
if not delete:
for state in deplist:
- (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True)
if deleted:
for child in deleted:
if self.cascade.delete_orphan and self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
- uowcommit.register_object(c._state, isdelete=True)
+ uowcommit.register_object(
+ attributes.instance_state(c), isdelete=True)
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
if associationrow is None:
def _pks_changed(self, uowcommit, state):
return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
-class AssociationDP(OneToManyDP):
- def __init__(self, *args, **kwargs):
- super(AssociationDP, self).__init__(*args, **kwargs)
- self.cascade.delete = True
- self.cascade.delete_orphan = True
-
class MapperStub(object):
"""Pose as a Mapper representing the association table in a
many-to-many join, when performing a ``flush()``.
-"""'dynamic' collection API. returns Query() objects on the 'read' side, alters
-a special AttributeHistory on the 'write' side."""
+# dynamic.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import exceptions, util, logging
-from sqlalchemy.orm import attributes, object_session, util as mapperutil, strategies
+"""Dynamic collection API.
+
+Dynamic collections act like Query() objects for read operations and support
+basic add/delete mutation.
+
+"""
+
+from sqlalchemy import log, util
+import sqlalchemy.exceptions as sa_exc
+
+from sqlalchemy.orm import attributes, object_session, \
+ util as mapperutil, strategies
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.mapper import has_identity, object_mapper
self.is_class_level = True
self._register_attribute(self.parent.class_, impl_class=DynamicAttributeImpl, target_mapper=self.parent_property.mapper, order_by=self.parent_property.order_by)
- def create_row_processor(self, selectcontext, mapper, row):
- return (None, None, None)
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ return (None, None)
-DynaLoader.logger = logging.class_logger(DynaLoader)
+DynaLoader.logger = log.class_logger(DynaLoader)
class DynamicAttributeImpl(attributes.AttributeImpl):
- def __init__(self, class_, key, typecallable, target_mapper, order_by, **kwargs):
- super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs)
+ uses_objects = True
+ accepts_scalar_loader = False
+
+ def __init__(self, class_, key, typecallable, class_manager, target_mapper, order_by, **kwargs):
+ super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, class_manager, **kwargs)
self.target_mapper = target_mapper
- self.order_by=order_by
+ self.order_by = order_by
self.query_class = AppenderQuery
def get(self, state, passive=False):
state.modified = True
if self.trackparent and value is not None:
- self.sethasparent(value._state, True)
- instance = state.obj()
+ self.sethasparent(attributes.instance_state(value), True)
for ext in self.extensions:
- ext.append(instance, value, initiator or self)
+ ext.append(state, value, initiator or self)
def fire_remove_event(self, state, value, initiator):
state.modified = True
if self.trackparent and value is not None:
- self.sethasparent(value._state, False)
+ self.sethasparent(attributes.instance_state(value), False)
- instance = state.obj()
for ext in self.extensions:
- ext.remove(instance, value, initiator or self)
+ ext.remove(state, value, initiator or self)
def set(self, state, value, initiator):
if initiator is self:
def session(self):
return self.__session()
- session = property(session)
+ session = property(session, lambda s, x: None)
def __iter__(self):
sess = self.__session()
if sess is None:
- return iter(self.attr._get_collection_history(self.instance._state, passive=True).added_items)
+ return iter(self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ passive=True).added_items)
else:
return iter(self._clone(sess))
def __getitem__(self, index):
sess = self.__session()
if sess is None:
- return self.attr._get_collection_history(self.instance._state, passive=True).added_items.__getitem__(index)
+ return self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ passive=True).added_items.__getitem__(index)
else:
return self._clone(sess).__getitem__(index)
def count(self):
sess = self.__session()
if sess is None:
- return len(self.attr._get_collection_history(self.instance._state, passive=True).added_items)
+ return len(self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ passive=True).added_items)
else:
return self._clone(sess).count()
if sess is None:
sess = object_session(instance)
if sess is None:
- try:
- sess = object_mapper(instance).get_session()
- except exceptions.InvalidRequestError:
- raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.instance_str(instance), self.attr.key))
+ raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.instance_str(instance), self.attr.key))
q = sess.query(self.attr.target_mapper).with_parent(instance, self.attr.key)
if self.attr.order_by:
oldlist = list(self)
else:
oldlist = []
- self.attr._get_collection_history(self.instance._state, passive=True).replace(oldlist, collection)
+ self.attr._get_collection_history(attributes.instance_state(self.instance), passive=True).replace(oldlist, collection)
return oldlist
def append(self, item):
- self.attr.append(self.instance._state, item, None)
+ self.attr.append(attributes.instance_state(self.instance), item, None)
def remove(self, item):
- self.attr.remove(self.instance._state, item, None)
+ self.attr.remove(attributes.instance_state(self.instance), item, None)
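
To see where these pieces surface, a sketch of the user-facing
behavior, assuming a hypothetical User/Address mapping and the
lazy='dynamic' relation flag::

    mapper(User, users_table, properties={
        'addresses': relation(Address, lazy='dynamic'),
    })

    user = session.query(User).first()
    # reads return an AppenderQuery; filtering happens in SQL,
    # not in memory
    q = user.addresses.filter(
        addresses_table.c.email_address.like('%@example.com'))
    # writes accumulate in CollectionHistory without loading
    # the full collection
    user.addresses.append(Address(email_address='new@example.com'))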
class CollectionHistory(object):
--- /dev/null
+# exc.py - ORM exceptions
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""SQLAlchemy ORM exceptions."""
+
+import sqlalchemy.exceptions as sa_exc
+
+
+class ConcurrentModificationError(sa_exc.SQLAlchemyError):
+ """Rows have been modified outside of the unit of work."""
+
+
+class FlushError(sa_exc.SQLAlchemyError):
+ """A invalid condition was detected during flush()."""
+
+
+class ObjectDeletedError(sa_exc.InvalidRequestError):
+ """An refresh() operation failed to re-retrieve an object's row."""
+
+
+class UnmappedColumnError(sa_exc.InvalidRequestError):
+ """Mapping operation was requested on an unknown column."""
+
+
+# Legacy compat until 0.6.
+sa_exc.ConcurrentModificationError = ConcurrentModificationError
+sa_exc.FlushError = FlushError
+sa_exc.UnmappedColumnError = UnmappedColumnError
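
With the aliases above in place, the new and legacy spellings resolve
to the same classes; a brief sketch, assuming an active session::

    import sqlalchemy.exceptions as sa_exc
    from sqlalchemy.orm import exc as orm_exc

    try:
        session.flush()
    except orm_exc.ConcurrentModificationError:
        session.rollback()

    # the legacy spelling remains importable during the transition
    assert sa_exc.ConcurrentModificationError is \
        orm_exc.ConcurrentModificationError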
--- /dev/null
+# identity.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import weakref
+
+from sqlalchemy import util as base_util
+from sqlalchemy.orm import attributes
+
+
+class IdentityMap(dict):
+ def __init__(self):
+ self._mutable_attrs = weakref.WeakKeyDictionary()
+ self.modified = False
+
+ def add(self, state):
+ raise NotImplementedError()
+
+ def remove(self, state):
+ raise NotImplementedError()
+
+ def update(self, dict):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def clear(self):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def _manage_incoming_state(self, state):
+ if state.modified:
+ self.modified = True
+ if state.manager.mutable_attributes:
+ self._mutable_attrs[state] = True
+
+ def _manage_removed_state(self, state):
+ if state in self._mutable_attrs:
+ del self._mutable_attrs[state]
+
+ def check_modified(self):
+ """return True if any InstanceStates present have been marked as 'modified'."""
+
+ if not self.modified:
+ for state in self._mutable_attrs:
+ if state.check_modified():
+ return True
+ else:
+ return False
+ else:
+ return True
+
+ def has_key(self, key):
+ return key in self
+
+ def popitem(self):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def pop(self, key, *args):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def setdefault(self, key, default=None):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def copy(self):
+ raise NotImplementedError()
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def __delitem__(self, key):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+class WeakInstanceDict(IdentityMap):
+
+ def __init__(self):
+ IdentityMap.__init__(self)
+ self._wr = weakref.ref(self)
+ # RLock because the mutex is used by a cleanup
+ # handler, which can be called at any time (including within an already mutexed block)
+ self._mutex = base_util.threading.RLock()
+
+ def __getitem__(self, key):
+ state = dict.__getitem__(self, key)
+ o = state.obj()
+ if o is None:
+ o = state._check_resurrect(self)
+ if o is None:
+ raise KeyError, key
+ return o
+
+ def __contains__(self, key):
+ try:
+ state = dict.__getitem__(self, key)
+ o = state.obj()
+ if o is None:
+ o = state._check_resurrect(self)
+ except KeyError:
+ return False
+ return o is not None
+
+ def contains_state(self, state):
+ return dict.get(self, state.key) is state
+
+ def add(self, state):
+ if state.key in self:
+ if dict.__getitem__(self, state.key) is not state:
+ raise AssertionError("A conflicting state is already present in the identity map for key %r" % state.key)
+ else:
+ dict.__setitem__(self, state.key, state)
+ state._instance_dict = self._wr
+ self._manage_incoming_state(state)
+
+ def remove_key(self, key):
+ state = dict.__getitem__(self, key)
+ self.remove(state)
+
+ def remove(self, state):
+ if not self.contains_state(state):
+ raise AssertionError("State %s is not present in this identity map" % state)
+ dict.__delitem__(self, state.key)
+ del state._instance_dict
+ self._manage_removed_state(state)
+
+ def discard(self, state):
+ if self.contains_state(state):
+ dict.__delitem__(self, state.key)
+ del state._instance_dict
+ self._manage_removed_state(state)
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def items(self):
+ return list(self.iteritems())
+
+ def iteritems(self):
+ for state in dict.itervalues(self):
+ value = state.obj()
+ if value is not None:
+ yield state.key, value
+
+ def itervalues(self):
+ for state in dict.itervalues(self):
+ instance = state.obj()
+ if instance is not None:
+ yield instance
+
+ def values(self):
+ return list(self.itervalues())
+
+ def all_states(self):
+ return dict.values(self)
+
+ def prune(self):
+ return 0
+
+class StrongInstanceDict(IdentityMap):
+ def all_states(self):
+ return [attributes.instance_state(o) for o in self.values()]
+
+ def contains_state(self, state):
+ return state.key in self and attributes.instance_state(self[state.key]) is state
+
+ def add(self, state):
+ dict.__setitem__(self, state.key, state.obj())
+ self._manage_incoming_state(state)
+
+ def remove(self, state):
+ if not self.contains_state(state):
+ raise AssertionError("State %s is not present in this identity map" % state)
+ dict.__delitem__(self, state.key)
+ self._manage_removed_state(state)
+
+ def discard(self, state):
+ if self.contains_state(state):
+ dict.__delitem__(self, state.key)
+ self._manage_removed_state(state)
+
+ def remove_key(self, key):
+ state = attributes.instance_state(dict.__getitem__(self, key))
+ self.remove(state)
+
+ def prune(self):
+ """prune unreferenced, non-dirty states."""
+
+ ref_count = len(self)
+ dirty = [s.obj() for s in self.all_states() if s.check_modified()]
+ keepers = weakref.WeakValueDictionary(self)
+ dict.clear(self)
+ dict.update(self, keepers)
+ self.modified = bool(dirty)
+ return ref_count - len(self)
+
+class IdentityManagedState(attributes.InstanceState):
+ def _instance_dict(self):
+ return None
+
+ def _check_resurrect(self, instance_dict):
+ instance_dict._mutex.acquire()
+ try:
+ return self.obj() or self.__resurrect(instance_dict)
+ finally:
+ instance_dict._mutex.release()
+
+ def modified_event(self, attr, should_copy, previous, passive=False):
+ attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive)
+
+ instance_dict = self._instance_dict()
+ if instance_dict:
+ instance_dict.modified = True
+
+ def _cleanup(self, ref):
+ # tiptoe around Python GC unpredictability
+ try:
+ instance_dict = self._instance_dict()
+ instance_dict._mutex.acquire()
+ except:
+ return
+ # the mutexing here is based on the assumption that gc.collect()
+ # may be firing off cleanup handlers in a different thread than that
+ # which is normally operating upon the instance dict.
+ try:
+ try:
+ self.__resurrect(instance_dict)
+ except:
+ # catch app cleanup exceptions. no other way around this
+ # without warnings being produced
+ pass
+ finally:
+ instance_dict._mutex.release()
+
+ def __resurrect(self, instance_dict):
+ if self.check_modified():
+ # store strong ref'ed version of the object; will revert
+ # to weakref when changes are persisted
+ obj = self.manager.new_instance(state=self)
+ self.obj = weakref.ref(obj, self._cleanup)
+ self._strong_obj = obj
+ # todo: revisit this wrt user-defined-state
+ obj.__dict__.update(self.dict)
+ self.dict = obj.__dict__
+ self._run_on_load(obj)
+ return obj
+ else:
+ instance_dict.remove(self)
+ self.dispose()
+ return None
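
For context, the weak and strong dictionaries above back the
Session's identity map; a sketch assuming the weak_identity_map
session flag of this era::

    from sqlalchemy.orm import sessionmaker

    # default: WeakInstanceDict; clean, otherwise-unreferenced
    # instances may be garbage collected
    Session = sessionmaker(weak_identity_map=True)

    # StrongInstanceDict: instances are referenced until pruned
    # or the session is closed
    StrongSession = sessionmaker(weak_identity_map=False)
    session = StrongSession()
    # ... load and mutate objects ...
    session.prune()  # evicts clean, otherwise-unreferenced instances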
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Semi-private implementation objects which form the basis
-of ORM-mapped attributes, query options and mapper extension.
+"""
+
+Semi-private implementation objects which form the basis of ORM-mapped
+attributes, query options and mapper extension.
+
+Defines the [sqlalchemy.orm.interfaces#MapperExtension] class, which can be
+end-user subclassed to add event-based functionality to mappers. The
+remainder of this module is generally private to the ORM.
-Defines the [sqlalchemy.orm.interfaces#MapperExtension] class,
-which can be end-user subclassed to add event-based functionality
-to mappers. The remainder of this module is generally private to the
-ORM.
"""
from itertools import chain
-from sqlalchemy import exceptions, logging, util
+
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import log, util
from sqlalchemy.sql import expression
-class_mapper = None
-__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
- 'MapperProperty', 'PropComparator', 'StrategizedProperty',
- 'build_path', 'MapperOption',
- 'ExtensionOption', 'PropertyOption',
- 'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ]
+class_mapper = None
+collections = None
+
+__all__ = (
+ 'AttributeExtension',
+ 'EXT_CONTINUE',
+ 'EXT_STOP',
+ 'ExtensionOption',
+ 'InstrumentationManager',
+ 'LoaderStrategy',
+ 'MapperExtension',
+ 'MapperOption',
+ 'MapperProperty',
+ 'PropComparator',
+ 'PropertyOption',
+ 'SessionExtension',
+ 'StrategizedOption',
+ 'StrategizedProperty',
+ 'build_path',
+ )
-EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE')
+EXT_CONTINUE = util.symbol('EXT_CONTINUE')
EXT_STOP = util.symbol('EXT_STOP')
ONETOMANY = util.symbol('ONETOMANY')
these exception cases, any return value other than EXT_CONTINUE or
EXT_STOP will be interpreted as equivalent to EXT_STOP.
- EXT_PASS is a synonym for EXT_CONTINUE and is provided for backward
- compatibility.
"""
-
def instrument_class(self, mapper, class_):
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
return EXT_CONTINUE
- def get_session(self):
- """Retrieve a contextual Session instance with which to
- register a new object.
-
- Note: this is not called if a session is provided with the
- `__init__` params (i.e. `_sa_session`).
- """
-
- return EXT_CONTINUE
-
def load(self, query, *args, **kwargs):
"""Override the `load` method of the Query object.
return EXT_CONTINUE
- def get_by(self, query, *args, **kwargs):
- """Override the `get_by` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.get_by()`` if the value is anything other than
- EXT_CONTINUE.
-
- DEPRECATED.
- """
-
- return EXT_CONTINUE
-
- def select_by(self, query, *args, **kwargs):
- """Override the `select_by` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.select_by()`` if the value is anything other than
- EXT_CONTINUE.
-
- DEPRECATED.
- """
-
- return EXT_CONTINUE
-
- def select(self, query, *args, **kwargs):
- """Override the `select` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.select()`` if the value is anything other than
- EXT_CONTINUE.
-
- DEPRECATED.
- """
-
- return EXT_CONTINUE
-
-
def translate_row(self, mapper, context, row):
"""Perform pre-processing on the given result row and return a
new row instance.
return EXT_CONTINUE
+class SessionExtension(object):
+ """An extension hook object for Sessions. Subclasses may be installed into a Session
+ (or sessionmaker) using the ``extension`` keyword argument.
+ """
+
+ def before_commit(self, session):
+ """Execute right before commit is called.
+
+ Note that this may not be per-flush if a longer-running transaction is ongoing."""
+
+ def after_commit(self, session):
+ """Execute after a commit has occured.
+
+ Note that this may not be per-flush if a longer-running transaction is ongoing."""
+
+ def after_rollback(self, session):
+ """Execute after a rollback has occured.
+
+ Note that this may not be per-flush if a longer-running transaction is ongoing."""
+
+ def before_flush(self, session, flush_context, instances):
+ """Execute before flush process has started.
+
+ `instances` is an optional list of objects which were passed to the ``flush()``
+ method.
+ """
+
+ def after_flush(self, session, flush_context):
+ """Execute after flush has completed, but before commit has been called.
+
+ Note that the session's state is still in pre-flush, i.e. 'new', 'dirty',
+ and 'deleted' lists still show pre-flush state as well as the history
+ settings on instance attributes."""
+
+ def after_flush_postexec(self, session, flush_context):
+ """Execute after flush has completed, and after the post-exec state occurs.
+
+ This will be when the 'new', 'dirty', and 'deleted' lists are in their final
+ state. An actual commit() may or may not have occured, depending on whether or not
+ the flush started its own transaction or participated in a larger transaction.
+ """
+
+ def after_begin(self, session, transaction, connection):
+ """Execute after a transaction is begun on a connection
+
+ `transaction` is the SessionTransaction. This method is called after an
+ engine level transaction is begun on a connection.
+ """
+
+
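
A minimal sketch of a SessionExtension subclass; the flush-counting
behavior is purely illustrative::

    from sqlalchemy.orm import sessionmaker

    class CountingExtension(SessionExtension):
        def before_flush(self, session, flush_context, instances):
            session._flush_count = getattr(session, '_flush_count', 0) + 1

        def after_commit(self, session):
            print "committed after %d flush(es)" % \
                getattr(session, '_flush_count', 0)

    Session = sessionmaker(extension=CountingExtension())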
class MapperProperty(object):
"""Manage the relationship of a ``Mapper`` to a single class
attribute, as well as that attribute as it appears on individual
attribute access, loading behavior, and dependency calculations.
"""
- def setup(self, querycontext, **kwargs):
+ def setup(self, context, entity, path, adapter, **kwargs):
"""Called by Query for the purposes of constructing a SQL statement.
Each MapperProperty associated with the target mapper processes the
pass
- def create_row_processor(self, selectcontext, mapper, row):
- """Return a 3-tuple consiting of two row processing functions and an instance post-processing function.
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ """Return a 2-tuple consiting of two row processing functions and an instance post-processing function.
Input arguments are the query.SelectionContext and the *first*
applicable row of a result set obtained within
columns present in the row (which will be the same columns present in
all rows) are used to determine the presence and behavior of the
returned callables. The callables will then be used to process all
- rows and to post-process all instances, respectively.
+ rows and instances.
Callables are of the following form::
- def new_execute(instance, row, **flags):
- # process incoming instance and given row. the instance is
+ def new_execute(state, row, **flags):
+ # process incoming instance state and given row. the instance is
# "new" and was just created upon receipt of this row.
# flags is a dictionary containing at least the following
# attributes:
# isnew - indicates if the instance was newly created as a
# result of reading this row
# instancekey - identity key of the instance
- # optional attribute:
- # ispostselect - indicates if this row resulted from a
- # 'post' select of additional tables/columns
- def existing_execute(instance, row, **flags):
- # process incoming instance and given row. the instance is
+ def existing_execute(state, row, **flags):
+ # process incoming instance state and given row. the instance is
# "existing" and was created based on a previous row.
- def post_execute(instance, **flags):
- # process instance after all result rows have been processed.
- # this function should be used to issue additional selections
- # in order to eagerly load additional properties.
-
- return (new_execute, existing_execute, post_execute)
+ return (new_execute, existing_execute)
Either of the two callables can be ``None``, in which case no function
is called.
return iter([])
- def get_criterion(self, query, key, value):
- """Return a ``WHERE`` clause suitable for this
- ``MapperProperty`` corresponding to the given key/value pair,
- where the key is a column or object property name, and value
- is a value to be matched. This is only picked up by
- ``PropertyLoaders``.
-
- This is called by a ``Query``'s ``join_by`` method to formulate a set
- of key/value pairs into a ``WHERE`` criterion that spans multiple
- tables if needed.
- """
-
- return None
-
def set_parent(self, parent):
self.parent = parent
which returns the MapperProperty associated with this
PropComparator.
"""
-
- def expression_element(self):
- return self.clause_element()
-
+
+ def __clause_element__(self):
+ raise NotImplementedError("%r" % self)
+
def contains_op(a, b):
return a.contains(b)
contains_op = staticmethod(contains_op)
``StrategizedOption`` objects via the Query.options() method.
"""
- def _get_context_strategy(self, context):
- path = context.path
- return self._get_strategy(context.attributes.get(("loaderstrategy", path), self.strategy.__class__))
-
+ def __get_context_strategy(self, context, path):
+ cls = context.attributes.get(("loaderstrategy", path), None)
+ if cls:
+ try:
+ return self.__all_strategies[cls]
+ except KeyError:
+ return self.__init_strategy(cls)
+ else:
+ return self.strategy
+
def _get_strategy(self, cls):
try:
- return self._all_strategies[cls]
+ return self.__all_strategies[cls]
except KeyError:
- # cache the located strategy per class for faster re-lookup
- strategy = cls(self)
- strategy.init()
- self._all_strategies[cls] = strategy
- return strategy
-
- def setup(self, querycontext, **kwargs):
- self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs)
+ return self.__init_strategy(cls)
+
+ def __init_strategy(self, cls):
+ self.__all_strategies[cls] = strategy = cls(self)
+ strategy.init()
+ return strategy
+
+ def setup(self, context, entity, path, adapter, **kwargs):
+ self.__get_context_strategy(context, path + (self.key,)).setup_query(context, entity, path, adapter, **kwargs)
- def create_row_processor(self, selectcontext, mapper, row):
- return self._get_context_strategy(selectcontext).create_row_processor(selectcontext, mapper, row)
+ def create_row_processor(self, context, path, mapper, row, adapter):
+ return self.__get_context_strategy(context, path + (self.key,)).create_row_processor(context, path, mapper, row, adapter)
def do_init(self):
- self._all_strategies = {}
- self.strategy = self._get_strategy(self.strategy_class)
+ self.__all_strategies = {}
+ self.strategy = self.__init_strategy(self.strategy_class)
if self.is_primary():
self.strategy.init_class_attribute()
-def build_path(mapper, key, prev=None):
+def build_path(entity, key, prev=None):
if prev:
- return prev + (mapper.base_mapper, key)
+ return prev + (entity, key)
else:
- return (mapper.base_mapper, key)
+ return (entity, key)
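
The resulting path is a flat tuple that alternates entities with
attribute keys; for example, with hypothetical user_mapper and
address_mapper::

    p = build_path(user_mapper, 'addresses')
    # p == (user_mapper, 'addresses')
    p = build_path(address_mapper, 'email_address', prev=p)
    # p == (user_mapper, 'addresses', address_mapper, 'email_address')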
def serialize_path(path):
if path is None:
self.ext = ext
def process_query(self, query):
- query._extension = query._extension.copy()
- query._extension.insert(self.ext)
-
+ entity = query._generate_mapper_zero()
+ entity.extension = entity.extension.copy()
+ entity.extension.push(self.ext)
class PropertyOption(MapperOption):
"""A MapperOption that is applied to a property off the mapper or
def _process(self, query, raiseerr):
if self._should_log_debug:
self.logger.debug("applying option to Query, property key '%s'" % self.key)
- paths = self._get_paths(query, raiseerr)
+ paths = self.__get_paths(query, raiseerr)
if paths:
self.process_query_property(query, paths)
def process_query_property(self, query, paths):
pass
+
+ def __find_entity(self, query, mapper, raiseerr):
+ from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class
+
+ if _is_aliased_class(mapper):
+ searchfor = mapper
+ else:
+ searchfor = _class_to_mapper(mapper).base_mapper
- def _get_paths(self, query, raiseerr):
+ for ent in query._mapper_entities:
+ if ent.path_entity is searchfor:
+ return ent
+ else:
+ if raiseerr:
+ raise sa_exc.ArgumentError("Can't find entity %s in Query. Current list: %r" % (searchfor, [str(m.path_entity) for m in query._entities]))
+ else:
+ return None
+
+ def __get_paths(self, query, raiseerr):
path = None
+ entity = None
l = []
+
current_path = list(query._current_path)
-
+
if self.mapper:
- global class_mapper
- if class_mapper is None:
- from sqlalchemy.orm import class_mapper
- mapper = self.mapper
- if isinstance(self.mapper, type):
- mapper = class_mapper(mapper)
- if mapper is not query.mapper and mapper not in [q.mapper for q in query._entities]:
- raise exceptions.ArgumentError("Can't find entity %s in Query. Current list: %r" % (str(mapper), [str(m) for m in query._entities]))
- else:
- mapper = query.mapper
- if isinstance(self.key, basestring):
- tokens = self.key.split('.')
- else:
- tokens = util.to_list(self.key)
+ entity = self.__find_entity(query, self.mapper, raiseerr)
+ mapper = entity.mapper
+ path_element = entity.path_entity
- for token in tokens:
- if isinstance(token, basestring):
- prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
- elif isinstance(token, PropComparator):
- prop = token.property
- token = prop.key
-
+ for key in util.to_list(self.key):
+ if isinstance(key, basestring):
+ tokens = key.split('.')
else:
- raise exceptions.ArgumentError("mapper option expects string key or list of attributes")
-
- if current_path and token == current_path[1]:
- current_path = current_path[2:]
- continue
+ tokens = [key]
+ for token in tokens:
+ if isinstance(token, basestring):
+ if not entity:
+ entity = query._entity_zero()
+ path_element = entity.path_entity
+ mapper = entity.mapper
+ prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
+ key = token
+ elif isinstance(token, PropComparator):
+ prop = token.property
+ if not entity:
+ entity = self.__find_entity(query, token.parententity, raiseerr)
+ if not entity:
+ return []
+ path_element = entity.path_entity
+ key = prop.key
+ else:
+ raise sa_exc.ArgumentError("mapper option expects string key or list of attributes")
+
+ if current_path and key == current_path[1]:
+ current_path = current_path[2:]
+ continue
- if prop is None:
- return []
- path = build_path(mapper, prop.key, path)
- l.append(path)
- if getattr(token, '_of_type', None):
- mapper = token._of_type
- else:
- mapper = getattr(prop, 'mapper', None)
+ if prop is None:
+ return []
+
+ path = build_path(path_element, prop.key, path)
+ l.append(path)
+ if getattr(token, '_of_type', None):
+ path_element = mapper = token._of_type
+ else:
+ path_element = mapper = getattr(prop, 'mapper', None)
+ if path_element:
+ path_element = path_element.base_mapper
+
return l
-PropertyOption.logger = logging.class_logger(PropertyOption)
-PropertyOption._should_log_debug = logging.is_debug_enabled(PropertyOption.logger)
+PropertyOption.logger = log.class_logger(PropertyOption)
+PropertyOption._should_log_debug = log.is_debug_enabled(PropertyOption.logger)
class AttributeExtension(object):
"""An abstract class which specifies `append`, `delete`, and `set`
def init_class_attribute(self):
pass
- def setup_query(self, context, **kwargs):
+ def setup_query(self, context, entity, path, adapter, **kwargs):
pass
- def create_row_processor(self, selectcontext, mapper, row):
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
"""Return row processing functions which fulfill the contract specified
by MapperProperty.create_row_processor.
"""
raise NotImplementedError()
+
+ def __str__(self):
+ return str(self.parent_property)
+
+ def debug_callable(self, fn, logger, announcement, logfn):
+ if announcement:
+ logger.debug(announcement)
+ if logfn:
+ def call(*args, **kwargs):
+ logger.debug(logfn(*args, **kwargs))
+ return fn(*args, **kwargs)
+ return call
+ else:
+ return fn
+
+class InstrumentationManager(object):
+ """User-defined class instrumentation extension."""
+
+ # r4361 added a mandatory (cls) constructor to this interface.
+ # given that, perhaps class_ should be dropped from all of these
+ # signatures.
+
+ def __init__(self, class_):
+ pass
+
+ def manage(self, class_, manager):
+ setattr(class_, '_default_class_manager', manager)
+
+ def dispose(self, class_, manager):
+ delattr(class_, '_default_class_manager')
+
+ def manager_getter(self, class_):
+ def get(cls):
+ return cls._default_class_manager
+ return get
+
+ def instrument_attribute(self, class_, key, inst):
+ pass
+
+ def install_descriptor(self, class_, key, inst):
+ setattr(class_, key, inst)
+
+ def uninstall_descriptor(self, class_, key):
+ delattr(class_, key)
+
+ def install_member(self, class_, key, implementation):
+ setattr(class_, key, implementation)
+
+ def uninstall_member(self, class_, key):
+ delattr(class_, key)
+
+ def instrument_collection_class(self, class_, key, collection_class):
+ global collections
+ if collections is None:
+ from sqlalchemy.orm import collections
+ return collections.prepare_instrumentation(collection_class)
+
+ def get_instance_dict(self, class_, instance):
+ return instance.__dict__
+
+ def initialize_instance_dict(self, class_, instance):
+ pass
+
+ def install_state(self, class_, instance, state):
+ setattr(instance, '_default_state', state)
+
+ def state_getter(self, class_):
+ return lambda instance: getattr(instance, '_default_state')
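
A sketch of electing a custom manager for a class, assuming the
__sa_instrumentation_manager__ class attribute hook; only the state
storage is overridden here, and the names are hypothetical::

    class MyInstrumentationManager(InstrumentationManager):
        # keep the InstanceState under a custom name instead of the
        # '_default_state' attribute used by the base class
        def install_state(self, class_, instance, state):
            instance.__dict__['_my_state'] = state

        def state_getter(self, class_):
            return lambda instance: instance.__dict__['_my_state']

    class MyEntity(object):
        __sa_instrumentation_manager__ = MyInstrumentationManager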
-# orm/mapper.py
+# mapper.py
# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Defines the [sqlalchemy.orm.mapper#Mapper] class, the central configurational
+"""Logic to map Python classes to and from selectables.
+
+Defines the [sqlalchemy.orm.mapper#Mapper] class, the central configurational
unit which associates a class with a database table.
This is a semi-private module; the main configurational API of the ORM is
available in [sqlalchemy.orm#].
+
"""
import weakref
from itertools import chain
-from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy.sql import expression, visitors, operators, util as sqlutil
-from sqlalchemy.orm import sync, attributes
-from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator
-from sqlalchemy.orm.util import has_identity, _state_has_identity, _is_mapped_class, has_mapper, \
- _state_mapper, class_mapper, object_mapper, _class_to_mapper,\
- ExtensionCarrier, state_str, instance_str
-
-__all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry']
+
+from sqlalchemy import sql, util, log
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy.sql import expression, visitors, operators
+import sqlalchemy.sql.util as sqlutil
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import exc
+from sqlalchemy.orm import sync
+from sqlalchemy.orm.identity import IdentityManagedState
+from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, \
+ PropComparator
+from sqlalchemy.orm.util import \
+ ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _is_mapped_class, \
+ _state_has_identity, _state_mapper, class_mapper, has_identity, \
+ has_mapper, instance_str, object_mapper, state_str
+
+
+__all__ = (
+ 'Mapper',
+ '_mapper_registry',
+ 'class_mapper',
+ 'object_mapper',
+ )
_mapper_registry = weakref.WeakKeyDictionary()
_new_mappers = False
SynonymProperty = None
ComparableProperty = None
_expire_state = None
+_state_session = None
class Mapper(object):
Mappers are normally constructed via the [sqlalchemy.orm#mapper()]
function. See for details.
-
+
"""
self.class_ = class_
+ self.class_manager = None
self.entity_name = entity_name
self.primary_key_argument = primary_key
self.non_primary = non_primary
self.eager_defaults = eager_defaults
self.column_prefix = column_prefix
self.polymorphic_on = polymorphic_on
- self._eager_loaders = util.Set()
self._dependency_processors = []
self._clause_adapter = None
self._requires_row_aliasing = False
self.__inherits_equated_pairs = None
-
+
if not issubclass(class_, object):
- raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
+ raise sa_exc.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
self.select_table = select_table
if select_table:
if with_polymorphic:
- raise exceptions.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)")
+ raise sa_exc.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)")
self.with_polymorphic = ('*', select_table)
else:
if with_polymorphic == '*':
else:
self.with_polymorphic = (with_polymorphic, None)
elif with_polymorphic is not None:
- raise exceptions.ArgumentError("Invalid setting for with_polymorphic")
+ raise sa_exc.ArgumentError("Invalid setting for with_polymorphic")
else:
self.with_polymorphic = None
-
+
if isinstance(self.local_table, expression._SelectBaseMixin):
util.warn("mapper %s creating an alias for the given selectable - use Class attributes for queries." % self)
self.local_table = self.local_table.alias()
-
+
if self.with_polymorphic and isinstance(self.with_polymorphic[1], expression._SelectBaseMixin):
self.with_polymorphic[1] = self.with_polymorphic[1].alias()
# indicates this Mapper should be used to construct the object instance for that row.
self.polymorphic_identity = polymorphic_identity
- if polymorphic_fetch not in (None, 'union', 'select', 'deferred'):
- raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch)
- if polymorphic_fetch is None:
- self.polymorphic_fetch = (self.with_polymorphic is None) and 'select' or 'union'
- else:
- self.polymorphic_fetch = polymorphic_fetch
+ if polymorphic_fetch:
+ util.warn_deprecated('polymorphic_fetch option is deprecated. Unloaded columns load as deferred in all cases; loading can be controlled using the "with_polymorphic" option.')
# a dictionary of 'polymorphic identity' names, associating those names with
# Mappers that will be used to construct object instances upon a select operation.
# a set of all mappers which inherit from this one.
self._inheriting_mappers = util.Set()
- self.__props_init = False
+ self.compiled = False
- self.__should_log_info = logging.is_info_enabled(self.logger)
- self.__should_log_debug = logging.is_debug_enabled(self.logger)
+ self.__should_log_info = log.is_info_enabled(self.logger)
+ self.__should_log_debug = log.is_debug_enabled(self.logger)
- self.__compile_class()
self.__compile_inheritance()
self.__compile_extensions()
+ self.__compile_class()
self.__compile_properties()
self.__compile_pks()
global _new_mappers
if self.__should_log_debug:
self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg)
- def _is_orphan(self, obj):
+ def _is_orphan(self, state):
o = False
for mapper in self.iterate_to_root():
- for (key,klass) in mapper.delete_orphans:
- if attributes.has_parent(klass, obj, key, optimistic=has_identity(obj)):
+ for (key, cls) in mapper.delete_orphans:
+ if attributes.manager_of_class(cls).has_parent(
+ state, key, optimistic=_state_has_identity(state)):
return False
o = o or bool(mapper.delete_orphans)
return o
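
The orphan check above now routes through the ClassManager rather than the
old module-level attributes.has_parent(). A minimal sketch of the migrated
spelling, assuming a hypothetical Address class mapped with a delete-orphan
cascade on an 'addresses' relation (the names are illustrative, not part of
this patch):

    from sqlalchemy.orm import attributes

    # old spelling: attributes.has_parent(Address, address, 'addresses')
    state = attributes.instance_state(address)
    manager = attributes.manager_of_class(Address)
    manager.has_parent(state, 'addresses', optimistic=False)
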
return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr)
def _get_property(self, key, resolve_synonyms=False, raiseerr=True):
- """private in-compilation version of get_property()."""
-
prop = self.__props.get(key, None)
if resolve_synonyms:
while isinstance(prop, SynonymProperty):
prop = self.__props.get(prop.name, None)
if prop is None and raiseerr:
- raise exceptions.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key))
+ raise sa_exc.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key))
return prop
def iterate_properties(self):
+ """return an iterator of all MapperProperty objects."""
self.compile()
return self.__props.itervalues()
- iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.")
+ iterate_properties = property(iterate_properties)
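
A hedged usage sketch for the accessor, assuming a hypothetical mapped class
User; the implicit compile() makes it safe to call before any query has run:

    from sqlalchemy.orm import class_mapper

    for prop in class_mapper(User).iterate_properties:
        print prop.key, type(prop).__name__
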
- def __adjust_wp_selectable(self, spec=None, selectable=False):
- """given a with_polymorphic() argument, resolve it against this mapper's with_polymorphic setting"""
-
- isdefault = False
- if self.with_polymorphic:
- isdefault = not spec and selectable is False
-
- if not spec:
- spec = self.with_polymorphic[0]
- if selectable is False:
- selectable = self.with_polymorphic[1]
-
- return spec, selectable, isdefault
-
def __mappers_from_spec(self, spec, selectable):
"""given a with_polymorphic() argument, return the set of mappers it represents.
-
+
Trims the list of mappers to just those represented within the given selectable, if present.
This helps some more legacy-ish mappings.
-
+
"""
if spec == '*':
mappers = list(self.polymorphic_iterator())
mappers = [_class_to_mapper(m) for m in util.to_list(spec)]
else:
mappers = []
-
+
if selectable:
- tables = util.Set(sqlutil.find_tables(selectable))
+ tables = util.Set(sqlutil.find_tables(selectable, include_aliases=True))
mappers = [m for m in mappers if m.local_table in tables]
-
+
return mappers
- __mappers_from_spec = util.conditional_cache_decorator(__mappers_from_spec)
-
+
def __selectable_from_mappers(self, mappers):
"""given a list of mappers (assumed to be within this mapper's inheritance hierarchy),
construct an outerjoin amongst those mapper's mapped tables.
-
+
"""
from_obj = self.mapped_table
for m in mappers:
if m is self:
continue
if m.concrete:
- raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.")
+ raise sa_exc.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.")
elif not m.single:
from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition)
-
+
return from_obj
- __selectable_from_mappers = util.conditional_cache_decorator(__selectable_from_mappers)
-
- def _with_polymorphic_mappers(self, spec=None, selectable=False):
- spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
- return self.__mappers_from_spec(spec, selectable, cache=isdefault)
-
- def _with_polymorphic_selectable(self, spec=None, selectable=False):
- spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
+
+ def _with_polymorphic_mappers(self):
+ if not self.with_polymorphic:
+ return [self]
+ return self.__mappers_from_spec(*self.with_polymorphic)
+ _with_polymorphic_mappers = property(util.cache_decorator(_with_polymorphic_mappers))
+
+ def _with_polymorphic_selectable(self):
+ if not self.with_polymorphic:
+ return self.mapped_table
+
+ spec, selectable = self.with_polymorphic
if selectable:
return selectable
else:
- return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable, cache=isdefault), cache=isdefault)
-
+ return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable))
+ _with_polymorphic_selectable = property(util.cache_decorator(_with_polymorphic_selectable))
+
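
Both cached properties resolve the (spec, selectable) tuple stored by the
constructor. A sketch of how that tuple gets populated at mapper() time,
assuming a hypothetical Employee/Manager joined-table hierarchy:

    mapper(Employee, employees,
           polymorphic_on=employees.c.type,
           polymorphic_identity='employee',
           # spec '*' with no explicit selectable: __selectable_from_mappers()
           # builds an outerjoin across the subclass tables
           with_polymorphic=('*', None))
    mapper(Manager, managers, inherits=Employee,
           polymorphic_identity='manager')
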
def _with_polymorphic_args(self, spec=None, selectable=False):
- spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
- mappers = self.__mappers_from_spec(spec, selectable, cache=isdefault)
+ if self.with_polymorphic:
+ if not spec:
+ spec = self.with_polymorphic[0]
+ if selectable is False:
+ selectable = self.with_polymorphic[1]
+
+ mappers = self.__mappers_from_spec(spec, selectable)
if selectable:
return mappers, selectable
else:
- return mappers, self.__selectable_from_mappers(mappers, cache=isdefault)
-
- def _iterate_polymorphic_properties(self, spec=None, selectable=False):
+ return mappers, self.__selectable_from_mappers(mappers)
+
+ def _iterate_polymorphic_properties(self, mappers=None):
+ if mappers is None:
+ mappers = self._with_polymorphic_mappers
return iter(util.OrderedSet(
- chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers(spec, selectable)])
+ chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers])
))
def properties(self):
raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.")
properties = property(properties)
- def compiled(self):
- """return True if this mapper is compiled"""
- return self.__props_init
- compiled = property(compiled)
-
def dispose(self):
- # disaable any attribute-based compilation
- self.__props_init = True
- try:
- del self.class_.c
- except AttributeError:
- pass
- if not self.non_primary and self.entity_name in self._class_state.mappers:
- del self._class_state.mappers[self.entity_name]
- if not self._class_state.mappers:
+ # Disable any attribute-based compilation.
+ self.compiled = True
+
+ manager = self.class_manager
+ mappers = manager.mappers
+
+ if not self.non_primary and self.entity_name in mappers:
+ del mappers[self.entity_name]
+ if not mappers and manager.info.get(_INSTRUMENTOR, False):
+ for legacy in _legacy_descriptors.keys():
+ manager.uninstall_member(legacy)
+ manager.events.remove_listener('on_init', _event_on_init)
+ manager.events.remove_listener('on_init_failure',
+ _event_on_init_failure)
+ manager.uninstall_member('__init__')
+ del manager.info[_INSTRUMENTOR]
attributes.unregister_class(self.class_)
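
dispose() is normally reached through clear_mappers(); a minimal sketch
(User/users are hypothetical):

    from sqlalchemy.orm import mapper, clear_mappers

    mapper(User, users)
    # dispose()s every registered mapper, removing the legacy descriptors,
    # the init event listeners and the __init__ trigger from each class
    clear_mappers()
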
def compile(self):
"""Compile this mapper and all other non-compiled mappers.
-
+
This method checks the local compiled status as well as for
- any new mappers that have been defined, and is safe to call
+ any new mappers that have been defined, and is safe to call
repeatedly.
"""
-
global _new_mappers
- if self.__props_init and not _new_mappers:
+ if self.compiled and not _new_mappers:
return self
_COMPILE_MUTEX.acquire()
try:
# double-check inside mutex
- if self.__props_init and not _new_mappers:
+ if self.compiled and not _new_mappers:
return self
# initialize properties on all mappers
for mapper in list(_mapper_registry):
- if not mapper.__props_init:
+ if not mapper.compiled:
mapper.__initialize_properties()
_new_mappers = False
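
A sketch of the idempotent, registry-wide behavior described above
(User/users hypothetical):

    m = mapper(User, users)
    m.compile()    # initializes properties on m and any other pending mappers
    m.compile()    # no-op: self.compiled is True and _new_mappers is False
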
def __initialize_properties(self):
"""Call the ``init()`` method on all ``MapperProperties``
attached to this mapper.
-
+
This is a deferred configuration step which is intended
to execute once all mappers have been constructed.
"""
if getattr(prop, 'key', None) is None:
prop.init(key, self)
self.__log("__initialize_properties() complete")
- self.__props_init = True
-
+ self.compiled = True
def __compile_extensions(self):
"""Go through the global_extensions list as well as the list
for ext in self.inherits.extension:
if ext not in extlist:
extlist.add(ext)
- ext.instrument_class(self, self.class_)
else:
for ext in global_extensions:
if isinstance(ext, type):
ext = ext()
if ext not in extlist:
extlist.add(ext)
- ext.instrument_class(self, self.class_)
self.extension = ExtensionCarrier()
for ext in extlist:
if self.inherits:
if isinstance(self.inherits, type):
self.inherits = class_mapper(self.inherits, compile=False)
- else:
- self.inherits = self.inherits
if not issubclass(self.class_, self.inherits.class_):
- raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__))
+ raise sa_exc.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__))
if self.non_primary != self.inherits.non_primary:
np = not self.non_primary and "primary" or "non-primary"
- raise exceptions.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np))
+ raise sa_exc.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np))
# inherit_condition is optional.
if self.local_table is None:
self.local_table = self.inherits.local_table
if mapper.polymorphic_on:
mapper._requires_row_aliasing = True
else:
- if self.inherit_condition is None:
+ if not self.inherit_condition:
# figure out inherit condition from our table to the immediate table
# of the inherited mapper, not its full table which could pull in other
            # stuff we don't want (allows test/inheritance.InheritTest4 to pass)
self.inherit_condition = sqlutil.join_condition(self.inherits.local_table, self.local_table)
self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition)
-
+
fks = util.to_set(self.inherit_foreign_keys)
self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks)
else:
self.mapped_table = self.local_table
- if self.polymorphic_identity is not None:
- self.inherits.polymorphic_map[self.polymorphic_identity] = self
- if self.polymorphic_on is None:
- for mapper in self.iterate_to_root():
- # try to set up polymorphic on using correesponding_column(); else leave
- # as None
- if mapper.polymorphic_on:
- self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on)
- break
- else:
- # TODO: this exception not covered
- raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
if self.polymorphic_identity and not self.concrete:
self._identity_class = self.inherits._identity_class
self.inherits._inheriting_mappers.add(self)
self.base_mapper = self.inherits.base_mapper
self._all_tables = self.inherits._all_tables
+
+ if self.polymorphic_identity is not None:
+ self.polymorphic_map[self.polymorphic_identity] = self
+ if not self.polymorphic_on:
+ for mapper in self.iterate_to_root():
+                # try to set up polymorphic_on using corresponding_column(); else leave
+ # as None
+ if mapper.polymorphic_on:
+ self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on)
+ break
+ else:
+ # TODO: this exception not covered
+                    raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in its hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
else:
self._all_tables = util.Set()
self.base_mapper = self
self.mapped_table = self.local_table
if self.polymorphic_identity:
if self.polymorphic_on is None:
- raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
+            raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in its hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
self.polymorphic_map[self.polymorphic_identity] = self
self._identity_class = self.class_
-
+
if self.mapped_table is None:
- raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self))
+ raise sa_exc.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self))
def __compile_pks(self):
self.tables = sqlutil.find_tables(self.mapped_table)
if not self.tables:
- raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
+ raise sa_exc.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
self._pks_by_table = {}
self._cols_by_table = {}
self._pks_by_table[k.table].add(k)
if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0:
- raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
+ raise sa_exc.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
if self.inherits and not self.concrete and not self.primary_key_argument:
# if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit)
primary_key = sqlutil.reduce_columns(self._pks_by_table[self.mapped_table])
if len(primary_key) == 0:
- raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
+ raise sa_exc.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
self.primary_key = primary_key
self.__log("Identified primary key columns: " + str(primary_key))
"""create a "get clause" based on the primary key. this is used
by query.get() and many-to-one lazyloads to load this item
by primary key.
-
+
"""
params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key]
return sql.and_(*[k==v for (k, v) in params]), dict(params)
_get_clause = property(util.cache_decorator(_get_clause))
-
+
def _equivalent_columns(self):
"""Create a map of all *equivalent* columns, based on
the determination of column pairs that are equated to
one another either by an established foreign key relationship
or by a joined-table inheritance join.
- This is used to determine the minimal set of primary key
- columns for the mapper, as well as when relating
- columns to those of a polymorphic selectable (i.e. a UNION of
- several mapped tables), as that selectable usually only contains
- one column in its columns clause out of a group of several which
- are equated to each other.
-
The resulting structure is a dictionary of columns mapped
to lists of equivalent columns, i.e.
result[binary.right] = util.Set([binary.left])
for mapper in self.base_mapper.polymorphic_iterator():
if mapper.inherit_condition:
- visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary)
+ visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary})
# TODO: matching of cols to foreign keys might better be generalized
# into general column translation (i.e. corresponding_column)
cls = object.__getattribute__(self, 'class_')
clskey = object.__getattribute__(self, 'key')
- if key.startswith('__'):
+ if key.startswith('__') and key != '__clause_element__':
return object.__getattribute__(self, key)
class_mapper(cls)
column_key = (self.column_prefix or '') + column.key
self._compile_property(column_key, column, init=False, setparent=True)
-
+
    # do a special check for the "discriminator" column, as it may only be present
# in the 'with_polymorphic' selectable but we need it for the base mapper
- if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
- col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on
- self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True)
-
+ if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
+ col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on
+ self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True)
+
def _adapt_inherited_property(self, key, prop):
if not self.concrete:
self._compile_property(key, prop, init=False, setparent=False)
columns = util.to_list(prop)
column = columns[0]
if not expression.is_column(column):
- raise exceptions.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop))
+ raise sa_exc.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop))
prop = self.__props.get(key, None)
for c in columns:
mc = self.mapped_table.corresponding_column(c)
if not mc:
- raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c))
+ raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c))
mapped_column.append(mc)
prop = ColumnProperty(*mapped_column)
else:
if not self.allow_column_override:
- raise exceptions.ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop)))
+ raise sa_exc.ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop)))
else:
return
if col is None:
col = prop.columns[0]
else:
- # if column is coming in after _cols_by_table was initialized, ensure the col is in the
+ # if column is coming in after _cols_by_table was initialized, ensure the col is in the
# right set
if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]:
self._cols_by_table[col.table].add(col)
for col in prop.columns:
for col in col.proxy_set:
self._columntoproperty[col] = prop
-
-
- elif isinstance(prop, SynonymProperty) and setparent:
+
+ elif isinstance(prop, (ComparableProperty, SynonymProperty)) and setparent:
if prop.descriptor is None:
prop.descriptor = getattr(self.class_, key, None)
if isinstance(prop.descriptor, Mapper._CompileOnAttr):
prop.descriptor = object.__getattribute__(prop.descriptor, 'existing_prop')
- if prop.map_column:
- if not key in self.mapped_table.c:
- raise exceptions.ArgumentError("Can't compile synonym '%s': no column on table '%s' named '%s'" % (prop.name, self.mapped_table.description, key))
+ if getattr(prop, 'map_column', False):
+ if key not in self.mapped_table.c:
+ raise sa_exc.ArgumentError("Can't compile synonym '%s': no column on table '%s' named '%s'" % (prop.name, self.mapped_table.description, key))
self._compile_property(prop.name, ColumnProperty(self.mapped_table.c[key]), init=init, setparent=setparent)
- elif isinstance(prop, ComparableProperty) and setparent:
- # refactor me
- if prop.descriptor is None:
- prop.descriptor = getattr(self.class_, key, None)
- if isinstance(prop.descriptor, Mapper._CompileOnAttr):
- prop.descriptor = object.__getattribute__(prop.descriptor,
- 'existing_prop')
+
self.__props[key] = prop
if setparent:
prop.set_parent(self)
if not self.non_primary:
- setattr(self.class_, key, Mapper._CompileOnAttr(self.class_, key))
-
+ self.class_manager.install_descriptor(
+ key, Mapper._CompileOnAttr(self.class_, key))
if init:
prop.init(key, self)
-
+
for mapper in self._inheriting_mappers:
mapper._adapt_inherited_property(key, prop)
auto-session attachment logic.
"""
+ manager = attributes.manager_of_class(self.class_)
+
if self.non_primary:
- if not hasattr(self.class_, '_class_state'):
- raise exceptions.InvalidRequestError("Class %s has no primary mapper configured. Configure a primary mapper first before setting up a non primary Mapper.")
- self._class_state = self.class_._class_state
+ if not manager or None not in manager.mappers:
+                raise sa_exc.InvalidRequestError(
+                    "Class %s has no primary mapper configured. Configure "
+                    "a primary mapper first before setting up a non primary "
+                    "Mapper." % self.class_)
+ self.class_manager = manager
_mapper_registry[self] = True
return
- if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers):
- raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper. clear_mappers() will remove *all* current mappers from all classes." % (self.class_, self.entity_name))
+ if manager is not None:
+ if manager.class_ is not self.class_:
+ # An inherited manager. Install one for this subclass.
+ manager = None
+ elif self.entity_name in manager.mappers:
+ raise sa_exc.ArgumentError(
+ "Class '%s' already has a primary mapper defined "
+ "with entity name '%s'. Use non_primary=True to "
+ "create a non primary Mapper. clear_mappers() will "
+ "remove *all* current mappers from all classes." %
+ (self.class_, self.entity_name))
- def extra_init(class_, oldinit, instance, args, kwargs):
- self.compile()
- if 'init_instance' in self.extension.methods:
- self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
+ _mapper_registry[self] = True
- def on_exception(class_, oldinit, instance, args, kwargs):
- util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
+ if manager is None:
+ manager = attributes.create_manager_for_cls(self.class_)
- attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes)
+ self.class_manager = manager
- self._class_state = self.class_._class_state
- _mapper_registry[self] = True
+ has_been_initialized = bool(manager.info.get(_INSTRUMENTOR, False))
+ manager.mappers[self.entity_name] = self
- self.class_._class_state.mappers[self.entity_name] = self
+        # The remaining members can be added by any mapper, entity_name None or not.
+ if has_been_initialized:
+ return
- for ext in util.to_list(self.extension, []):
- ext.instrument_class(self, self.class_)
+ self.extension.instrument_class(self, self.class_)
- if self.entity_name is None:
- self.class_.c = self.c
+ manager.instantiable = True
+ manager.instance_state_factory = IdentityManagedState
+ manager.deferred_scalar_loader = _load_scalar_attributes
+
+ event_registry = manager.events
+ event_registry.add_listener('on_init', _event_on_init)
+ event_registry.add_listener('on_init_failure', _event_on_init_failure)
+
+ for key, impl in _legacy_descriptors.items():
+ manager.install_member(key, impl)
+
+ manager.info[_INSTRUMENTOR] = self
def common_parent(self, other):
"""Return true if the given mapper shares a common inherited parent as this mapper."""
return self.base_mapper is other.base_mapper
+ def _canload(self, state):
+ s = self.primary_mapper()
+ if s.polymorphic_on:
+ return _state_mapper(state).isa(s)
+ else:
+ return _state_mapper(state) is s
+
def isa(self, other):
- """Return True if the given mapper inherits from this mapper."""
+        """Return True if this mapper inherits from the given mapper."""
- m = other
- while m is not self and m.inherits:
+ m = self
+ while m and m is not other:
m = m.inherits
- return m is self
+ return bool(m)
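
A hedged illustration of the reversed traversal, given hypothetical
Employee/Manager mappers where Manager inherits from Employee:

    manager_mapper.isa(employee_mapper)     # True
    employee_mapper.isa(manager_mapper)     # False
    employee_mapper.isa(employee_mapper)    # True; a mapper isa() itself
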
def iterate_to_root(self):
m = self
"""
self._init_properties[key] = prop
- self._compile_property(key, prop, init=self.__props_init)
+ self._compile_property(key, prop, init=self.compiled)
+
+ def __repr__(self):
+ return '<Mapper at 0x%x; %s>' % (
+ id(self), self.class_.__name__)
def __str__(self):
return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "")
def primary_mapper(self):
"""Return the primary mapper corresponding to this mapper's class key (class + entity_name)."""
- return self._class_state.mappers[self.entity_name]
-
- def get_session(self):
- """Return the contextual session provided by the mapper
- extension chain, if any.
-
- Raise ``InvalidRequestError`` if a session cannot be retrieved
- from the extension chain.
- """
-
- if 'get_session' in self.extension.methods:
- s = self.extension.get_session()
- if s is not EXT_CONTINUE:
- return s
-
- raise exceptions.InvalidRequestError("No contextual Session is established.")
-
- def instances(self, cursor, session, *mappers, **kwargs):
- """Return a list of mapped instances corresponding to the rows
- in a given ResultProxy.
-
- DEPRECATED.
- """
+ return self.class_manager.mappers[self.entity_name]
- import sqlalchemy.orm.query
- return sqlalchemy.orm.Query(self, session).instances(cursor, *mappers, **kwargs)
- instances = util.deprecated(None, False)(instances)
-
- def identity_key_from_row(self, row):
+ def identity_key_from_row(self, row, adapter=None):
"""Return an identity-map key for use in storing/retrieving an
item from the identity map.
dictionary corresponding result-set ``ColumnElement``
instances to their values within a row.
"""
- return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name)
+
+ pk_cols = self.primary_key
+ if adapter:
+ pk_cols = [adapter.columns[c] for c in pk_cols]
+
+ return (self._identity_class, tuple([row[column] for column in pk_cols]), self.entity_name)
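
The returned tuple is (identity class, primary key values, entity_name).
For example, assuming a hypothetical User class mapped to a table with a
single integer primary key and the default entity_name of None:

    mapper.identity_key_from_row(row)    # -> (User, (7,), None)
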
def identity_key_from_primary_key(self, primary_key):
"""Return an identity-map key for use in storing/retrieving an
"""Return the identity key for the given instance, based on
its primary key attributes.
- This value is typically also found on the instance itself
- under the attribute name `_instance_key`.
+ This value is typically also found on the instance state under the
+ attribute name `key`.
+
"""
return self.identity_key_from_primary_key(self.primary_key_from_instance(instance))
"""Return the list of primary key values for the given
instance.
"""
-
- return [self._get_state_attr_by_column(instance._state, column) for column in self.primary_key]
+ state = attributes.instance_state(instance)
+ return self._primary_key_from_state(state)
def _primary_key_from_state(self, state):
return [self._get_state_attr_by_column(state, column) for column in self.primary_key]
- def _canload(self, state):
- if self.polymorphic_on:
- return issubclass(state.class_, self.class_)
- else:
- return state.class_ is self.class_
def _get_col_to_prop(self, column):
try:
except KeyError:
prop = self.__props.get(column.key, None)
if prop:
- raise exceptions.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop)))
+ raise exc.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop)))
else:
- raise exceptions.UnmappedColumnError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
+ raise exc.UnmappedColumnError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
+ # TODO: improve names
def _get_state_attr_by_column(self, state, column):
return self._get_col_to_prop(column).getattr(state, column)
def _set_state_attr_by_column(self, state, column, value):
return self._get_col_to_prop(column).setattr(state, value, column)
- def _get_attr_by_column(self, obj, column):
- return self._get_col_to_prop(column).getattr(obj._state, column)
-
def _get_committed_attr_by_column(self, obj, column):
- return self._get_col_to_prop(column).getcommitted(obj._state, column)
+ state = attributes.instance_state(obj)
+ return self._get_committed_state_attr_by_column(state, column)
- def _set_attr_by_column(self, obj, column, value):
- self._get_col_to_prop(column).setattr(obj._state, column, value)
+ def _get_committed_state_attr_by_column(self, state, column):
+ return self._get_col_to_prop(column).getcommitted(state, column)
def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
# organize individual states with the connection to use for insert/update
if 'connection_callable' in uowtransaction.mapper_flush_opts:
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
- tups = [(state, connection_callable(self, state.obj()), _state_has_identity(state)) for state in states]
+ tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in states]
else:
connection = uowtransaction.transaction.connection(self)
- tups = [(state, connection, _state_has_identity(state)) for state in states]
+ tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in states]
if not postupdate:
# call before_XXX extensions
- for state, connection, has_identity in tups:
- mapper = _state_mapper(state)
+ for state, mapper, connection, has_identity in tups:
if not has_identity:
if 'before_insert' in mapper.extension.methods:
mapper.extension.before_insert(mapper, connection, state.obj())
if 'before_update' in mapper.extension.methods:
mapper.extension.before_update(mapper, connection, state.obj())
- for state, connection, has_identity in tups:
+ for state, mapper, connection, has_identity in tups:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
# and another instance with the same identity key already exists as persistent. convert to an
# UPDATE if so.
- mapper = _state_mapper(state)
instance_key = mapper._identity_key_from_state(state)
- if not postupdate and not has_identity and instance_key in uowtransaction.uow.identity_map:
- existing = uowtransaction.uow.identity_map[instance_key]._state
+ if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map:
+ instance = uowtransaction.session.identity_map[instance_key]
+ existing = attributes.instance_state(instance)
if not uowtransaction.is_deleted(existing):
- raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing)))
+ raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing)))
if self.__should_log_debug:
self.__log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing)))
uowtransaction.set_row_switch(existing)
insert = []
update = []
- for state, connection, has_identity in tups:
- mapper = _state_mapper(state)
+ for state, mapper, connection, has_identity in tups:
if table not in mapper._pks_by_table:
continue
pks = mapper._pks_by_table[table]
if self.__should_log_debug:
self.__log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key)))
- isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity
+ isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity
params = {}
value_params = {}
hasdata = False
pks = mapper._pks_by_table[table]
def comparator(a, b):
for col in pks:
- x = cmp(a[1][col._label],b[1][col._label])
+ x = cmp(a[1][col._label], b[1][col._label])
if x != 0:
return x
return 0
rows += c.rowcount
if c.supports_sane_rowcount() and rows != len(update):
- raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
+ raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
if insert:
statement = table.insert()
if not postupdate:
# call after_XXX extensions
- for state, connection, has_identity in tups:
- mapper = _state_mapper(state)
+ for state, mapper, connection, has_identity in tups:
if not has_identity:
if 'after_insert' in mapper.extension.methods:
mapper.extension.after_insert(mapper, connection, state.obj())
if deferred_props:
if self.eager_defaults:
- _instance_key = self._identity_key_from_state(state)
- state.dict['_instance_key'] = _instance_key
- uowtransaction.session.query(self)._get(_instance_key, refresh_instance=state, only_load_props=deferred_props)
+ state.key = self._identity_key_from_state(state)
+ uowtransaction.session.query(self)._get(
+ state.key, refresh_instance=state,
+ only_load_props=deferred_props)
else:
_expire_state(state, deferred_props)
if 'connection_callable' in uowtransaction.mapper_flush_opts:
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
- tups = [(state, connection_callable(self, state.obj())) for state in states]
+ tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in states]
else:
connection = uowtransaction.transaction.connection(self)
- tups = [(state, connection) for state in states]
+ tups = [(state, _state_mapper(state), connection) for state in states]
- for (state, connection) in tups:
- mapper = _state_mapper(state)
+ for state, mapper, connection in tups:
if 'before_delete' in mapper.extension.methods:
mapper.extension.before_delete(mapper, connection, state.obj())
- deleted_objects = util.Set()
table_to_mapper = {}
for mapper in self.base_mapper.polymorphic_iterator():
for t in mapper.tables:
for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True):
delete = {}
- for (state, connection) in tups:
- mapper = _state_mapper(state)
+ for state, mapper, connection in tups:
if table not in mapper._pks_by_table:
continue
params[col.key] = mapper._get_state_attr_by_column(state, col)
if mapper.version_id_col and table.c.contains_column(mapper.version_id_col):
params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col)
- # testlib.pragma exempt:__hash__
- deleted_objects.add((state, connection))
+
for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
def comparator(a, b):
for col in mapper._pks_by_table[table]:
- x = cmp(a[col.key],b[col.key])
+ x = cmp(a[col.key], b[col.key])
if x != 0:
return x
return 0
statement = table.delete(clause)
c = connection.execute(statement, del_objects)
if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects):
- raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
+ raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
- for state, connection in deleted_objects:
- mapper = _state_mapper(state)
+ for state, mapper, connection in tups:
if 'after_delete' in mapper.extension.methods:
mapper.extension.after_delete(mapper, connection, state.obj())
visitables = [(self.__props.itervalues(), 'property', state)]
while visitables:
- iterator,item_type,parent_state = visitables[-1]
+ iterator, item_type, parent_state = visitables[-1]
try:
if item_type == 'property':
prop = iterator.next()
except StopIteration:
visitables.pop()
- def _instance(self, context, row, result=None, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None):
- if not extension:
- extension = self.extension
+ def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None):
+ pk_cols = self.primary_key
- if 'translate_row' in extension.methods:
- ret = extension.translate_row(self, context, row)
- if ret is not EXT_CONTINUE:
- row = ret
-
- if polymorphic_from:
- # if we are called from a base mapper doing a polymorphic load, figure out what tables,
- # if any, will need to be "post-fetched" based on the tables present in the row,
- # or from the options set up on the query
- if ('polymorphic_fetch', self) not in context.attributes:
- if self in context.query._with_polymorphic:
- context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [])
- else:
- context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [t for t in self.tables if t not in polymorphic_from.tables])
-
- elif not refresh_instance and self.polymorphic_on:
- discriminator = row[self.polymorphic_on]
- if discriminator is not None:
- try:
- mapper = self.polymorphic_map[discriminator]
- except KeyError:
- raise exceptions.AssertionError("No such polymorphic_identity %r is defined" % discriminator)
- if mapper is not self:
- return mapper._instance(context, row, result=result, polymorphic_from=self)
-
- # determine identity key
- if refresh_instance:
- try:
- identitykey = refresh_instance.dict['_instance_key']
- except KeyError:
- # super-rare condition; a refresh is being called
- # on a non-instance-key instance; this is meant to only
- # occur wihtin a flush()
- identitykey = self._identity_key_from_state(refresh_instance)
+ if polymorphic_from or refresh_instance:
+ polymorphic_on = None
else:
- identitykey = self.identity_key_from_row(row)
-
- session_identity_map = context.session.identity_map
+ polymorphic_on = self.polymorphic_on
+ polymorphic_instances = util.PopulateDict(self.__configure_subclass_mapper(context, path, adapter))
- if identitykey in session_identity_map:
- instance = session_identity_map[identitykey]
- state = instance._state
-
- if self.__should_log_debug:
- self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey)))
-
- isnew = state.runid != context.runid
- currentload = not isnew
-
- if not currentload and context.version_check and self.version_id_col and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
- raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
- elif refresh_instance:
- # out of band refresh_instance detected (i.e. its not in the session.identity_map)
- # honor it anyway. this can happen if a _get() occurs within save_obj(), such as
- # when eager_defaults is True.
- state = refresh_instance
- instance = state.obj()
- isnew = state.runid != context.runid
- currentload = True
- else:
- if self.__should_log_debug:
- self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
+ version_id_col = self.version_id_col
- if self.allow_null_pks:
- for x in identitykey[1]:
- if x is not None:
- break
- else:
- return None
- else:
- if None in identitykey[1]:
- return None
- isnew = True
- currentload = True
-
- if 'create_instance' in extension.methods:
- instance = extension.create_instance(self, context, row, self.class_)
- if instance is EXT_CONTINUE:
- instance = attributes.new_instance(self.class_)
- else:
- attributes.manage(instance)
- else:
- instance = attributes.new_instance(self.class_)
+ if adapter:
+ pk_cols = [adapter.columns[c] for c in pk_cols]
+ if polymorphic_on:
+ polymorphic_on = adapter.columns[polymorphic_on]
+ if version_id_col:
+ version_id_col = adapter.columns[version_id_col]
+
+ identity_class, entity_name = self._identity_class, self.entity_name
+ def identity_key(row):
+ return (identity_class, tuple([row[column] for column in pk_cols]), entity_name)
- if self.__should_log_debug:
- self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey)))
+ new_populators = []
+ existing_populators = []
- state = instance._state
- instance._entity_name = self.entity_name
- instance._instance_key = identitykey
- instance._sa_session_id = context.session.hash_key
- session_identity_map[identitykey] = instance
+ def populate_state(state, row, isnew, only_load_props, **flags):
+ if not new_populators:
+ new_populators[:], existing_populators[:] = self.__populators(context, path, row, adapter)
- if currentload or context.populate_existing or self.always_refresh:
if isnew:
- state.runid = context.runid
- context.progress.add(state)
+ populators = new_populators
+ else:
+ populators = existing_populators
- if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
- self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew)
-
- else:
- # populate attributes on non-loading instances which have been expired
- # TODO: also support deferred attributes here [ticket:870]
- if state.expired_attributes:
- if state in context.partials:
- isnew = False
- attrs = context.partials[state]
- else:
- isnew = True
- attrs = state.expired_attributes.intersection(state.unmodified)
- context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs
+ if only_load_props:
+ populators = [p for p in populators if p[0] in only_load_props]
- if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
- self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew)
-
- if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
- result.append(instance)
-
- return instance
-
- def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags):
- """populate an instance from a result row."""
-
- snapshot = selectcontext.path + (self,)
- # retrieve a set of "row population" functions derived from the MapperProperties attached
- # to this Mapper. These are keyed in the select context based primarily off the
- # "snapshot" of the stack, which represents a path from the lead mapper in the query to this one,
- # including relation() names. the key also includes "self", and allows us to distinguish between
- # other mappers within our inheritance hierarchy
- (new_populators, existing_populators) = selectcontext.attributes.get(('populators', self, snapshot, ispostselect), (None, None))
- if new_populators is None:
- # no populators; therefore this is the first time we are receiving a row for
- # this result set. issue create_row_processor() on all MapperProperty objects
- # and cache in the select context.
- new_populators = []
- existing_populators = []
- post_processors = []
- for prop in self.__props.values():
- (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row)
- if newpop:
- new_populators.append((prop.key, newpop))
- if existingpop:
- existing_populators.append((prop.key, existingpop))
- if post_proc:
- post_processors.append(post_proc)
-
- # install a post processor for immediate post-load of joined-table inheriting mappers
- poly_select_loader = self._get_poly_select_loader(selectcontext, row)
- if poly_select_loader:
- post_processors.append(poly_select_loader)
-
- selectcontext.attributes[('populators', self, snapshot, ispostselect)] = (new_populators, existing_populators)
- selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors
-
- if isnew or ispostselect:
- populators = new_populators
- else:
- populators = existing_populators
+ for key, populator in populators:
+ populator(state, row, isnew=isnew, **flags)
- if only_load_props:
- populators = [p for p in populators if p[0] in only_load_props]
+ session_identity_map = context.session.identity_map
- for (key, populator) in populators:
- selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
+ if not extension:
+ extension = self.extension
- if self.non_primary:
- selectcontext.attributes[('populating_mapper', instance._state)] = self
+ translate_row = 'translate_row' in extension.methods
+ create_instance = 'create_instance' in extension.methods
+ populate_instance = 'populate_instance' in extension.methods
+ append_result = 'append_result' in extension.methods
+ populate_existing = context.populate_existing or self.always_refresh
+
+ def _instance(row, result):
+ if translate_row:
+ ret = extension.translate_row(self, context, row)
+ if ret is not EXT_CONTINUE:
+ row = ret
+
+ if polymorphic_on:
+ discriminator = row[polymorphic_on]
+ if discriminator is not None:
+ _instance = polymorphic_instances[discriminator]
+ if _instance:
+ return _instance(row, result)
+
+ # determine identity key
+ if refresh_instance:
+ # TODO: refresh_instance seems to be named wrongly -- it is always an instance state.
+ refresh_state = refresh_instance
+ identitykey = refresh_state.key
+ if identitykey is None:
+ # super-rare condition; a refresh is being called
+ # on a non-instance-key instance; this is meant to only
+ # occur within a flush()
+ identitykey = self._identity_key_from_state(refresh_state)
+ else:
+ identitykey = identity_key(row)
- def _post_instance(self, selectcontext, state, **kwargs):
- post_processors = selectcontext.attributes[('post_processors', self, None)]
- for p in post_processors:
- p(state.obj(), **kwargs)
+ if identitykey in session_identity_map:
+ instance = session_identity_map[identitykey]
+ state = attributes.instance_state(instance)
- def _get_poly_select_loader(self, selectcontext, row):
- """set up attribute loaders for 'select' and 'deferred' polymorphic loading.
+ if self.__should_log_debug:
+ self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey)))
+
+ isnew = state.runid != context.runid
+ currentload = not isnew
+ loaded_instance = False
+
+ if not currentload and version_id_col and context.version_check and self._get_state_attr_by_column(state, self.version_id_col) != row[version_id_col]:
+ raise exc.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (state_str(state), self._get_state_attr_by_column(state, self.version_id_col), row[version_id_col]))
+ elif refresh_instance:
+ # out of band refresh_instance detected (i.e. its not in the session.identity_map)
+ # honor it anyway. this can happen if a _get() occurs within save_obj(), such as
+ # when eager_defaults is True.
+ state = refresh_instance
+ instance = state.obj()
+ isnew = state.runid != context.runid
+ currentload = True
+ loaded_instance = False
+ else:
+ if self.__should_log_debug:
+ self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
- this loading uses a second SELECT statement to load additional tables,
- either immediately after loading the main table or via a deferred attribute trigger.
- """
+ if self.allow_null_pks:
+ for x in identitykey[1]:
+ if x is not None:
+ break
+ else:
+ return None
+ else:
+ if None in identitykey[1]:
+ return None
+ isnew = True
+ currentload = True
+ loaded_instance = True
+
+ if create_instance:
+ instance = extension.create_instance(self, context, row, self.class_)
+ if instance is EXT_CONTINUE:
+ instance = self.class_manager.new_instance()
+ else:
+                    manager = attributes.manager_of_class(instance.__class__)
+ # TODO: if manager is None, raise a friendly error about
+ # returning instances of unmapped types
+ manager.setup_instance(instance)
+ else:
+ instance = self.class_manager.new_instance()
- (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None))
+ if self.__should_log_debug:
+ self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey)))
- if hosted_mapper is None or not needs_tables:
- return
+ state = attributes.instance_state(instance)
+ state.entity_name = self.entity_name
+ state.key = identitykey
+ # manually adding instance to session. for a complete add,
+ # session._finalize_loaded() must be called.
+ state.session_id = context.session.hash_key
+ session_identity_map.add(state)
- cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables)
- statement = sql.select(needs_tables, cond, use_labels=True)
+ if currentload or populate_existing:
+ if isnew:
+ state.runid = context.runid
+ context.progress.add(state)
- if hosted_mapper.polymorphic_fetch == 'select':
- def post_execute(instance, **flags):
- if self.__should_log_debug:
- self.__log_debug("Post query loading instance " + instance_str(instance))
+ if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+ populate_state(state, row, isnew, only_load_props)
- identitykey = self.identity_key_from_instance(instance)
-
- only_load_props = flags.get('only_load_props', None)
+ else:
+ # populate attributes on non-loading instances which have been expired
+ # TODO: also support deferred attributes here [ticket:870]
+ # TODO: apply eager loads to un-lazy loaded collections ?
+ # we might want to create an expanded form of 'state.expired_attributes' which includes deferred/un-lazy loaded
+ if state.expired_attributes:
+ if state in context.partials:
+ isnew = False
+ attrs = context.partials[state]
+ else:
+ isnew = True
+ attrs = state.expired_attributes.intersection(state.unmodified)
+ context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs
- params = {}
- for c, bind in param_names:
- params[bind] = self._get_attr_by_column(instance, c)
- row = selectcontext.session.connection(self).execute(statement, params).fetchone()
- self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True, only_load_props=only_load_props)
- return post_execute
- elif hosted_mapper.polymorphic_fetch == 'deferred':
- from sqlalchemy.orm.strategies import DeferredColumnLoader
-
- def post_execute(instance, **flags):
- def create_statement(instance):
- params = {}
- for (c, bind) in param_names:
- # use the "committed" (database) version to get query column values
- params[bind] = self._get_committed_attr_by_column(instance, c)
- return (statement, params)
-
- props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__]
- keys = [p.key for p in props]
-
- only_load_props = flags.get('only_load_props', None)
- if only_load_props:
- keys = util.Set(keys).difference(only_load_props)
- props = [p for p in props if p.key in only_load_props]
-
- for prop in props:
- strategy = prop._get_strategy(DeferredColumnLoader)
- instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement))
- return post_execute
- else:
- return None
+ if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+ populate_state(state, row, isnew, attrs, instancekey=identitykey)
+
+ if result is not None and (not append_result or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
+ result.append(instance)
- def _deferred_inheritance_condition(self, base_mapper, needs_tables):
- base_mapper = base_mapper.primary_mapper()
+ if loaded_instance:
+ state._run_on_load(instance)
+
+ return instance
+ return _instance
+
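
_instance_processor() replaces the old per-row _instance() method with a
closure built once per query path: extension-method membership tests, column
adaptation and the identity-key constructor are all resolved up front,
leaving only row-varying work in the inner function. A generic sketch of
that hoisting pattern, under assumed names rather than this module's exact
API:

    def make_row_processor(extension, adapter, pk_cols):
        # resolved once, when the processor is constructed
        translate = 'translate_row' in extension.methods
        if adapter:
            cols = [adapter.columns[c] for c in pk_cols]
        else:
            cols = pk_cols

        def process(row, result):
            # per-row fast path: only row-dependent work remains
            if translate:
                row = extension.translate_row(row)
            result.append(tuple([row[c] for c in cols]))
        return process
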
+ def __populators(self, context, path, row, adapter):
+ new_populators, existing_populators = [], []
+ for prop in self.__props.values():
+ newpop, existingpop = prop.create_row_processor(context, path, self, row, adapter)
+ if newpop:
+ new_populators.append((prop.key, newpop))
+ if existingpop:
+ existing_populators.append((prop.key, existingpop))
+ return new_populators, existing_populators
+
+ def __configure_subclass_mapper(self, context, path, adapter):
+ def configure_subclass_mapper(discriminator):
+ try:
+ mapper = self.polymorphic_map[discriminator]
+ except KeyError:
+ raise AssertionError("No such polymorphic_identity %r is defined" % discriminator)
+ if mapper is self:
+ return None
+ return mapper._instance_processor(context, path, adapter, polymorphic_from=self)
+ return configure_subclass_mapper
+
+ def _optimized_get_statement(self, state, attribute_names):
+ props = self.__props
+ tables = util.Set([props[key].parent.local_table for key in attribute_names])
+ if self.base_mapper.local_table in tables:
+ return None
def visit_binary(binary):
leftcol = binary.left
rightcol = binary.right
if leftcol is None or rightcol is None:
return
- if leftcol.table not in needs_tables:
- binary.left = sql.bindparam(None, None, type_=binary.right.type)
- param_names.append((leftcol, binary.left))
- elif rightcol not in needs_tables:
- binary.right = sql.bindparam(None, None, type_=binary.right.type)
- param_names.append((rightcol, binary.right))
+
+ if leftcol.table not in tables:
+ binary.left = sql.bindparam(None, self._get_committed_state_attr_by_column(state, leftcol), type_=binary.right.type)
+ elif rightcol.table not in tables:
+ binary.right = sql.bindparam(None, self._get_committed_state_attr_by_column(state, rightcol), type_=binary.right.type)
allconds = []
- param_names = []
- for mapper in self.iterate_to_root():
- if mapper is base_mapper:
- break
- allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
+ start = False
+ for mapper in util.reversed(list(self.iterate_to_root())):
+ if mapper.local_table in tables:
+ start = True
+ if start:
+ allconds.append(visitors.cloned_traverse(mapper.inherit_condition, {}, {'binary':visit_binary}))
+
+ cond = sql.and_(*allconds)
+ return sql.select(tables, cond, use_labels=True)
+
+Mapper.logger = log.class_logger(Mapper)
+
- return sql.and_(*allconds), param_names
+def _event_on_init(state, instance, args, kwargs):
+ """Trigger mapper compilation and run init_instance hooks."""
-Mapper.logger = logging.class_logger(Mapper)
+ instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
+ # compile() always compiles all mappers
+ instrumenting_mapper.compile()
+ if 'init_instance' in instrumenting_mapper.extension.methods:
+ instrumenting_mapper.extension.init_instance(
+ instrumenting_mapper, instrumenting_mapper.class_,
+ state.manager.events.original_init,
+ instance, args, kwargs)
+def _event_on_init_failure(state, instance, args, kwargs):
+ """Run init_failed hooks."""
-object_session = None
+ instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
+ if 'init_failed' in instrumenting_mapper.extension.methods:
+ util.warn_exception(
+ instrumenting_mapper.extension.init_failed,
+ instrumenting_mapper, instrumenting_mapper.class_,
+ state.manager.events.original_init, instance, args, kwargs)
-def _load_scalar_attributes(instance, attribute_names):
- mapper = object_mapper(instance)
- global object_session
- if not object_session:
- from sqlalchemy.orm.session import object_session
- session = object_session(instance)
+def _legacy_descriptors():
+    """Build compatibility descriptors mapping legacy instance attributes to InstanceState.
+
+ These are slated for removal in 0.5. They were never part of the
+ official public API but were suggested as temporary workarounds in a
+ number of mailing list posts. Permanent and public solutions for those
+ needs should be available now. Consult the applicable mailing list
+ threads for details.
+
+ """
+ def _instance_key(self):
+ state = attributes.instance_state(self)
+ if state.key is not None:
+ return state.key
+ else:
+ raise AttributeError("_instance_key")
+ _instance_key = util.deprecated(None, False)(_instance_key)
+ _instance_key = property(_instance_key)
+
+ def _sa_session_id(self):
+ state = attributes.instance_state(self)
+ if state.session_id is not None:
+ return state.session_id
+ else:
+ raise AttributeError("_sa_session_id")
+ _sa_session_id = util.deprecated(None, False)(_sa_session_id)
+ _sa_session_id = property(_sa_session_id)
+
+ def _entity_name(self):
+ state = attributes.instance_state(self)
+ if state.entity_name is attributes.NO_ENTITY_NAME:
+ return None
+ else:
+ return state.entity_name
+ _entity_name = util.deprecated(None, False)(_entity_name)
+ _entity_name = property(_entity_name)
+
+ return dict(locals())
+_legacy_descriptors = _legacy_descriptors()
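Each descriptor above is a plain function wrapped by util.deprecated() and property(). The same proxy-with-deprecation idiom in miniature; every name here is illustrative rather than SQLAlchemy API:

    import warnings

    class _State(object):
        # stand-in for InstanceState: out-of-band, per-instance bookkeeping
        def __init__(self):
            self.key = None

    class Widget(object):
        def __init__(self):
            self._state = _State()

        def _instance_key(self):
            warnings.warn("_instance_key is deprecated", DeprecationWarning)
            if self._state.key is not None:
                return self._state.key
            # raising AttributeError preserves hasattr() semantics
            raise AttributeError("_instance_key")
        _instance_key = property(_instance_key)

    w = Widget()
    print hasattr(w, '_instance_key')   # False: state not yet populated
    w._state.key = ('Widget', (1,))
    print w._instance_key               # ('Widget', (1,))
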
+
+def _load_scalar_attributes(state, attribute_names):
+ mapper = _state_mapper(state)
+ session = _state_session(state)
if not session:
- try:
- session = mapper.get_session()
- except exceptions.InvalidRequestError:
- raise exceptions.UnboundExecutionError("Instance %s is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed" % (instance.__class__))
-
- state = instance._state
- if '_instance_key' in state.dict:
- identity_key = state.dict['_instance_key']
- shouldraise = True
- else:
- # if instance is pending, a refresh operation may not complete (even if PK attributes are assigned)
- shouldraise = False
- identity_key = mapper._identity_key_from_state(state)
-
- if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None and shouldraise:
- raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
+ raise sa_exc.UnboundExecutionError("Instance %s is not bound to a Session; attribute refresh operation cannot proceed" % (state_str(state)))
+
+ has_key = _state_has_identity(state)
+
+ result = False
+ if mapper.inherits and not mapper.concrete:
+ statement = mapper._optimized_get_statement(state, attribute_names)
+ if statement:
+ result = session.query(mapper).from_statement(statement)._get(None, only_load_props=attribute_names, refresh_instance=state)
+
+ if result is False:
+ if has_key:
+ identity_key = state.key
+ else:
+ identity_key = mapper._identity_key_from_state(state)
+ result = session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names)
+ # if instance is pending, a refresh operation may not complete (even if PK attributes are assigned)
+ if has_key and result is None:
+ raise exc.ObjectDeletedError("Instance '%s' has been deleted." % state_str(state))
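_load_scalar_attributes() is the path taken whenever an expired or deferred attribute is touched. A rough end-to-end illustration using the public API of this era; the engine URL, table, and class are arbitrary:

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String
    from sqlalchemy.orm import mapper, sessionmaker

    engine = create_engine('sqlite://')
    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)))
    metadata.create_all(engine)

    class User(object):
        pass
    mapper(User, users)

    session = sessionmaker(bind=engine)()
    u = User()
    u.name = 'ed'
    session.save(u)      # session.add() in later releases
    session.flush()

    session.expire(u)    # unload all scalar attributes
    print u.name         # attribute access triggers the refresh path above
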
"""MapperProperty implementations.
-This is a private module which defines the behavior of
-invidual ORM-mapped attributes.
+This is a private module which defines the behavior of individual ORM-mapped
+attributes.
+
"""
-from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, find_columns
-from sqlalchemy.sql import visitors, operators, ColumnElement
-from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
-from sqlalchemy.orm import session as sessionlib
-from sqlalchemy.orm.mapper import _class_to_mapper
-from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
-from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY
-from sqlalchemy.exceptions import ArgumentError
+from sqlalchemy import sql, util, log
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs
+from sqlalchemy.sql import operators, ColumnElement, expression
+from sqlalchemy.orm import mapper, strategies, attributes, dependency, \
+ object_mapper, session as sessionlib
+from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate
+from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, \
+ MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY
__all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty',
'ComparableProperty', 'PropertyLoader', 'BackRef')
appears across each table.
"""
- self.columns = list(columns)
+ self.columns = [expression._labeled(c) for c in columns]
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
self.comparator = ColumnProperty.ColumnComparator(self)
+ util.set_creation_order(self)
if self.deferred:
self.strategy_class = strategies.DeferredColumnLoader
else:
self.strategy_class = strategies.ColumnLoader
- # sanity check
- for col in columns:
- if not isinstance(col, ColumnElement):
- raise ArgumentError('column_property() must be given a ColumnElement as its argument. Try .label() or .as_scalar() for Selectables to fix this.')
def do_init(self):
super(ColumnProperty, self).do_init()
return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
def getattr(self, state, column):
- return getattr(state.class_, self.key).impl.get(state)
+ return state.get_impl(self.key).get(state)
def getcommitted(self, state, column):
- return getattr(state.class_, self.key).impl.get_committed_value(state)
+ return state.get_impl(self.key).get_committed_value(state)
def setattr(self, state, value, column):
- getattr(state.class_, self.key).impl.set(state, value, None)
+ state.get_impl(self.key).set(state, value, None)
def merge(self, session, source, dest, dont_load, _recursive):
- value = attributes.get_as_list(source._state, self.key, passive=True)
+ value = attributes.instance_state(source).value_as_iterable(
+ self.key, passive=True)
if value:
setattr(dest, self.key, value[0])
else:
- # TODO: lazy callable should merge to the new instance
- dest._state.expire_attributes([self.key])
+ attributes.instance_state(dest).expire_attributes([self.key])
def get_col_value(self, column, value):
return value
class ColumnComparator(PropComparator):
- def clause_element(self):
- return self.prop.columns[0]
-
+ def __clause_element__(self):
+ return self.prop.columns[0]._annotate({"parententity": self.prop.parent})
+ __clause_element__ = util.cache_decorator(__clause_element__)
+
def operate(self, op, *other, **kwargs):
- return op(self.prop.columns[0], *other, **kwargs)
+ return op(self.__clause_element__(), *other, **kwargs)
def reverse_operate(self, op, other, **kwargs):
- col = self.prop.columns[0]
+ col = self.__clause_element__()
return op(col._bind_param(other), col, **kwargs)
-ColumnProperty.logger = logging.class_logger(ColumnProperty)
+ def __str__(self):
+ return str(self.parent.class_.__name__) + "." + self.key
+
+ColumnProperty.logger = log.class_logger(ColumnProperty)
class CompositeProperty(ColumnProperty):
"""subclasses ColumnProperty to provide composite type support."""
super(CompositeProperty, self).__init__(*columns, **kwargs)
self.composite_class = class_
self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
+ self.strategy_class = strategies.CompositeColumnLoader
def do_init(self):
super(ColumnProperty, self).do_init()
return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
def getattr(self, state, column):
- obj = getattr(state.class_, self.key).impl.get(state)
+ obj = state.get_impl(self.key).get(state)
return self.get_col_value(column, obj)
def getcommitted(self, state, column):
- obj = getattr(state.class_, self.key).impl.get_committed_value(state)
+ obj = state.get_impl(self.key).get_committed_value(state)
return self.get_col_value(column, obj)
def setattr(self, state, value, column):
# TODO: test coverage for this method
- obj = getattr(state.class_, self.key).impl.get(state)
+ obj = state.get_impl(self.key).get(state)
if obj is None:
obj = self.composite_class(*[None for c in self.columns])
- getattr(state.class_, self.key).impl.set(state, obj, None)
+ state.get_impl(self.key).set(state, obj, None)
for a, b in zip(self.columns, value.__composite_values__()):
if a is column:
return b
class Comparator(PropComparator):
+ def __clause_element__(self):
+ return expression.ClauseList(*self.prop.columns)
+
def __eq__(self, other):
if other is None:
return sql.and_(*[a==None for a in self.prop.columns])
zip(self.prop.columns,
other.__composite_values__())])
+ def __str__(self):
+ return str(self.parent.class_.__name__) + "." + self.key
+
class SynonymProperty(MapperProperty):
def __init__(self, name, map_column=None, descriptor=None):
self.name = name
- self.map_column=map_column
+ self.map_column = map_column
self.descriptor = descriptor
+ util.set_creation_order(self)
- def setup(self, querycontext, **kwargs):
+ def setup(self, context, entity, path, adapter, **kwargs):
pass
- def create_row_processor(self, selectcontext, mapper, row):
- return (None, None, None)
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ return (None, None)
def do_init(self):
class_ = self.parent.class_
return s
return getattr(obj, self.name)
self.descriptor = SynonymProp()
- sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator)
+ sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent)
def merge(self, session, source, dest, _recursive):
pass
-SynonymProperty.logger = logging.class_logger(SynonymProperty)
-
+SynonymProperty.logger = log.class_logger(SynonymProperty)
class ComparableProperty(MapperProperty):
"""Instruments a Python property for use in query expressions."""
def __init__(self, comparator_factory, descriptor=None):
self.descriptor = descriptor
self.comparator = comparator_factory(self)
+ util.set_creation_order(self)
def do_init(self):
"""Set up a proxy to the unmanaged descriptor."""
useobject=False,
comparator=self.comparator)
- def setup(self, querycontext, **kwargs):
+ def setup(self, context, entity, path, adapter, **kwargs):
pass
- def create_row_processor(self, selectcontext, mapper, row):
- return (None, None, None)
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ return (None, None)
class PropertyLoader(StrategizedProperty):
of items that correspond to a related database table.
"""
- def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, passive_updates=True, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None, _local_remote_pairs=None):
+ def __init__(self, argument,
+ secondary=None, primaryjoin=None,
+ secondaryjoin=None, entity_name=None,
+ foreign_keys=None,
+ uselist=None,
+ order_by=False,
+ backref=None,
+ _is_backref=False,
+ post_update=False,
+ cascade=None,
+ viewonly=False, lazy=True,
+ collection_class=None, passive_deletes=False,
+ passive_updates=True, remote_side=None,
+ enable_typechecks=True, join_depth=None,
+ strategy_class=None, _local_remote_pairs=None):
+
self.uselist = uselist
self.argument = argument
self.entity_name = entity_name
self.viewonly = viewonly
self.lazy = lazy
self.foreign_keys = util.to_set(foreign_keys)
- self._legacy_foreignkey = util.to_set(foreignkey)
- if foreignkey:
- util.warn_deprecated('foreignkey option is deprecated; see docs for details')
self.collection_class = collection_class
self.passive_deletes = passive_deletes
self.passive_updates = passive_updates
self.comparator = PropertyLoader.Comparator(self)
self.join_depth = join_depth
self._arg_local_remote_pairs = _local_remote_pairs
+ self.__join_cache = {}
+ util.set_creation_order(self)
if strategy_class:
self.strategy_class = strategy_class
if cascade is not None:
self.cascade = CascadeOptions(cascade)
else:
- if private:
- util.warn_deprecated('private option is deprecated; see docs for details')
- self.cascade = CascadeOptions("all, delete-orphan")
- else:
- self.cascade = CascadeOptions("save-update, merge")
+ self.cascade = CascadeOptions("save-update, merge")
if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade):
- raise exceptions.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade")
+ raise sa_exc.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade")
- self.association = association
- if association:
- util.warn_deprecated('association option is deprecated; see docs for details')
self.order_by = order_by
- self.attributeext=attributeext
+
if isinstance(backref, str):
            # propagate explicitly sent primary/secondary join conditions to the BackRef object if
# just a string was sent
self.backref = BackRef(backref, primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, passive_updates=self.passive_updates)
else:
self.backref = backref
- self.is_backref = is_backref
-
+ self._is_backref = _is_backref
+
class Comparator(PropComparator):
def __init__(self, prop, of_type=None):
self.prop = self.property = prop
if of_type:
self._of_type = _class_to_mapper(of_type)
+ def parententity(self):
+ return self.prop.parent
+ parententity = property(parententity)
+
+ def __clause_element__(self):
+ return self.prop.parent._with_polymorphic_selectable
+
def of_type(self, cls):
return PropertyLoader.Comparator(self.prop, cls)
return self.prop._optimized_compare(None)
elif self.prop.uselist:
if not hasattr(other, '__iter__'):
- raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object. Use contains().")
+ raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object. Use contains().")
else:
j = self.prop.primaryjoin
if self.prop.secondaryjoin:
else:
return self.prop._optimized_compare(other)
- def _join_and_criterion(self, criterion=None, **kwargs):
+ def __criterion_exists(self, criterion=None, **kwargs):
if getattr(self, '_of_type', None):
target_mapper = self._of_type
- to_selectable = target_mapper._with_polymorphic_selectable() #mapped_table
+ to_selectable = target_mapper._with_polymorphic_selectable
else:
to_selectable = None
- pj, sj, source, dest, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable)
+ pj, sj, source, dest, secondary, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable)
for k in kwargs:
- crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+ crit = self.prop.mapper.class_manager.get_inst(k) == kwargs[k]
if criterion is None:
criterion = crit
else:
criterion = criterion & crit
if sj:
- j = pj & sj
+ j = _orm_annotate(pj) & sj
else:
- j = pj
+ j = _orm_annotate(pj, exclude=self.prop.remote_side)
if criterion and target_adapter:
+ # limit this adapter to annotated only?
criterion = target_adapter.traverse(criterion)
- return j, criterion, dest
+ # only have the "joined left side" of what we return be subject to Query adaption. The right
+ # side of it is used for an exists() subquery and should not correlate or otherwise reach out
+ # to anything in the enclosing query.
+ if criterion:
+ criterion = criterion._annotate({'_halt_adapt': True})
+ return sql.exists([1], j & criterion, from_obj=dest).correlate(source)
def any(self, criterion=None, **kwargs):
if not self.prop.uselist:
- raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
- j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs)
+ raise sa_exc.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
- return sql.exists([1], j & criterion, from_obj=from_obj)
+ return self.__criterion_exists(criterion, **kwargs)
def has(self, criterion=None, **kwargs):
if self.prop.uselist:
- raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().")
- j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs)
-
- return sql.exists([1], j & criterion, from_obj=from_obj)
+ raise sa_exc.InvalidRequestError("'has()' not implemented for collections. Use any().")
+ return self.__criterion_exists(criterion, **kwargs)
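Both any() and has() now funnel through __criterion_exists(), yielding an EXISTS subquery correlated to the source selectable. Typical use, assuming a User/Address mapping with an 'addresses' collection and a scalar 'user' relation (names illustrative):

    # users owning at least one gmail address
    session.query(User).filter(
        User.addresses.any(Address.email.like('%@gmail.com')))

    # addresses whose user is named 'ed'; keyword arguments compare
    # attributes of the related class
    session.query(Address).filter(Address.user.has(name='ed'))
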
def contains(self, other):
if not self.prop.uselist:
- raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==")
+ raise sa_exc.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==")
clause = self.prop._optimized_compare(other)
if self.prop.secondaryjoin:
- clause.negation_clause = self._negated_contains_or_equals(other)
+ clause.negation_clause = self.__negated_contains_or_equals(other)
return clause
- def _negated_contains_or_equals(self, other):
+ def __negated_contains_or_equals(self, other):
criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
- j, criterion, from_obj = self._join_and_criterion(criterion)
- return ~sql.exists([1], j & criterion, from_obj=from_obj)
+ return ~self.__criterion_exists(criterion)
def __ne__(self, other):
if other is None:
return self.has()
if self.prop.uselist and not hasattr(other, '__iter__'):
- raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+ raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object")
- return self._negated_contains_or_equals(other)
+ return self.__negated_contains_or_equals(other)
def compare(self, op, value, value_is_parent=False):
if op == operators.eq:
return op(self.comparator, value)
def _optimized_compare(self, value, value_is_parent=False):
+ if value is not None:
+ value = attributes.instance_state(value)
return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent)
- def private(self):
- return self.cascade.delete_orphan
- private = property(private)
-
def __str__(self):
- return str(self.parent.class_.__name__) + "." + self.key + " (" + str(self.mapper.class_.__name__) + ")"
+ return str(self.parent.class_.__name__) + "." + self.key
def merge(self, session, source, dest, dont_load, _recursive):
if not dont_load and self._reverse_property and (source, self._reverse_property) in _recursive:
return
-
+
+ source_state = attributes.instance_state(source)
+ dest_state = attributes.instance_state(dest)
+
if not "merge" in self.cascade:
- dest._state.expire_attributes([self.key])
+ dest_state.expire_attributes([self.key])
return
- instances = attributes.get_as_list(source._state, self.key, passive=True)
+ instances = source_state.value_as_iterable(self.key, passive=True)
+
if not instances:
return
-
+
if self.uselist:
dest_list = []
for current in instances:
if obj is not None:
dest_list.append(obj)
if dont_load:
- coll = attributes.init_collection(dest, self.key)
+ coll = attributes.init_collection(dest_state, self.key)
for c in dest_list:
coll.append_without_event(c)
else:
- getattr(dest.__class__, self.key).impl._set_iterable(dest._state, dest_list)
+ getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list)
else:
current = instances[0]
if current is not None:
return
passive = type_ != 'delete' or self.passive_deletes
mapper = self.mapper.primary_mapper()
- instances = attributes.get_as_list(state, self.key, passive=passive)
+ instances = state.value_as_iterable(self.key, passive=passive)
if instances:
for c in instances:
if c is not None and c not in visited_instances and (halt_on is None or not halt_on(c)):
if not isinstance(c, self.mapper.class_):
- raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
+ raise AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
visited_instances.add(c)
# cascade using the mapper local to this object, so that its individual properties are located
instance_mapper = object_mapper(c, entity_name=mapper.entity_name)
- yield (c, instance_mapper, c._state)
+ yield (c, instance_mapper, attributes.instance_state(c))
def _get_target_class(self):
"""Return the target class of the relation, even if the
# accept a callable to suit various deferred-configurational schemes
self.mapper = mapper.class_mapper(self.argument(), entity_name=self.entity_name, compile=False)
else:
- raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument)))
+ raise sa_exc.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument)))
+ assert isinstance(self.mapper, mapper.Mapper), self.mapper
if not self.parent.concrete:
for inheriting in self.parent.iterate_to_root():
if self.cascade.delete_orphan:
if self.parent.class_ is self.mapper.class_:
- raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade "
+ raise sa_exc.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade "
"rule on a self-referential relationship. "
"You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_))
def __determine_joins(self):
if self.secondaryjoin is not None and self.secondary is None:
- raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
+ raise sa_exc.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
# if join conditions were not specified, figure them out based on foreign keys
def _search_for_join(mapper, table):
is a join."""
try:
return sql.join(mapper.local_table, table)
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
return sql.join(mapper.mapped_table, table)
try:
else:
if self.primaryjoin is None:
self.primaryjoin = _search_for_join(self.parent, self.target).onclause
- except exceptions.ArgumentError, e:
- raise exceptions.ArgumentError("Could not determine join condition between parent/child tables on relation %s. "
+ except sa_exc.ArgumentError, e:
+ raise sa_exc.ArgumentError("Could not determine join condition between parent/child tables on relation %s. "
"Specify a 'primaryjoin' expression. If this is a many-to-many relation, 'secondaryjoin' is needed as well." % (self))
def __determine_fks(self):
- if self._legacy_foreignkey and not self._refers_to_parent_table():
- self.foreign_keys = self._legacy_foreignkey
-
arg_foreign_keys = self.foreign_keys
if self._arg_local_remote_pairs:
if not arg_foreign_keys:
- raise exceptions.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument")
+ raise sa_exc.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument")
self.foreign_keys = util.OrderedSet(arg_foreign_keys)
self._opposite_side = util.OrderedSet()
for l, r in self._arg_local_remote_pairs:
if not eq_pairs:
if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
- raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. "
+ raise sa_exc.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. "
"For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.primaryjoin, self)
)
else:
if arg_foreign_keys:
- raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
+ raise sa_exc.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
"Specify _local_remote_pairs=[(local, remote), (local, remote), ...] to explicitly establish the local/remote column pairs." % (self.primaryjoin, self))
else:
- raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
+ raise sa_exc.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
"Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self))
self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs])
if not sq_pairs:
if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
- raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for secondaryjoin condition '%s' on relation %s. "
+ raise sa_exc.ArgumentError("Could not locate any equated, locally mapped column pairs for secondaryjoin condition '%s' on relation %s. "
"For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.secondaryjoin, self)
)
else:
- raise exceptions.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. "
+ raise sa_exc.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. "
"Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.secondaryjoin, self))
self.foreign_keys.update([r for l, r in sq_pairs])
def __determine_remote_side(self):
if self._arg_local_remote_pairs:
if self.remote_side:
- raise exceptions.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.")
+ raise sa_exc.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.")
if self.direction is MANYTOONE:
eq_pairs = [(r, l) for l, r in self._arg_local_remote_pairs]
else:
if self.direction is ONETOMANY:
for l in self.local_side:
if not self.__col_is_part_of_mappings(l):
- raise exceptions.ArgumentError("Local column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should compare against." % (l, self.parent))
+ raise sa_exc.ArgumentError("Local column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should compare against." % (l, self.parent))
elif self.direction is MANYTOONE:
for r in self.remote_side:
if not self.__col_is_part_of_mappings(r):
- raise exceptions.ArgumentError("Remote column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should bind." % (r, self.mapper))
+ raise sa_exc.ArgumentError("Remote column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should bind." % (r, self.mapper))
def __determine_direction(self):
"""Determine our *direction*, i.e. do we represent one to
# for a self referential mapper, if the "foreignkey" is a single or composite primary key,
# then we are "many to one", since the remote site of the relationship identifies a singular entity.
# otherwise we are "one to many".
- if self._legacy_foreignkey:
- for f in self._legacy_foreignkey:
- if not f.primary_key:
- self.direction = ONETOMANY
- else:
- self.direction = MANYTOONE
- elif self._arg_local_remote_pairs:
+ if self._arg_local_remote_pairs:
remote = util.Set([r for l, r in self._arg_local_remote_pairs])
if self.foreign_keys.intersection(remote):
self.direction = ONETOMANY
manytoone = [c for c in self.foreign_keys if parenttable.c.contains_column(c)]
if not onetomany and not manytoone:
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"Can't determine relation direction for relationship '%s' "
"- foreign key columns are present in neither the "
"parent nor the child's mapped tables" %(str(self)))
self.direction = MANYTOONE
break
else:
- raise exceptions.ArgumentError(
+ raise sa_exc.ArgumentError(
"Can't determine relation direction for relationship '%s' "
"- foreign key columns are present in both the parent and "
"the child's mapped tables. Specify 'foreign_keys' "
"argument." % (str(self)))
def _post_init(self):
- if logging.is_info_enabled(self.logger):
+ if log.is_info_enabled(self.logger):
self.logger.info(str(self) + " setup primary join %s" % self.primaryjoin)
self.logger.info(str(self) + " setup secondary join %s" % self.secondaryjoin)
self.logger.info(str(self) + " synchronize pairs [%s]" % ",".join(["(%s => %s)" % (l, r) for l, r in self.synchronize_pairs]))
# primary property handler, set up class attributes
if self.is_primary():
- # if a backref name is defined, set up an extension to populate
- # attributes in the other direction
- if self.backref is not None:
- self.attributeext = self.backref.get_extension()
-
if self.backref is not None:
self.backref.compile(self)
elif not mapper.class_mapper(self.parent.class_, compile=False)._get_property(self.key, raiseerr=False):
- raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
+ raise sa_exc.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
super(PropertyLoader, self).do_init()
return self.mapper.common_parent(self.parent)
def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None):
+ key = util.WeakCompositeKey(source_polymorphic, source_selectable, dest_polymorphic, dest_selectable)
+ try:
+ return self.__join_cache[key]
+ except KeyError:
+ pass
+
if source_selectable is None:
if source_polymorphic and self.parent.with_polymorphic:
- source_selectable = self.parent._with_polymorphic_selectable()
- else:
- source_selectable = None
+ source_selectable = self.parent._with_polymorphic_selectable
+
+ aliased = False
if dest_selectable is None:
if dest_polymorphic and self.mapper.with_polymorphic:
- dest_selectable = self.mapper._with_polymorphic_selectable()
+ dest_selectable = self.mapper._with_polymorphic_selectable
+ aliased = True
else:
dest_selectable = self.mapper.mapped_table
-        if self._is_self_referential():
+
+            if self._is_self_referential() and source_selectable is None:
+                dest_selectable = dest_selectable.alias()
+                aliased = True
+        else:
+            aliased = True
+
+        aliased = aliased or bool(source_selectable)
+
+ primaryjoin, secondaryjoin, secondary = self.primaryjoin, self.secondaryjoin, self.secondary
+ if aliased:
+ if secondary:
+ secondary = secondary.alias()
+ primary_aliasizer = ClauseAdapter(secondary)
if dest_selectable:
- dest_selectable = dest_selectable.alias()
+ secondary_aliasizer = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns).chain(primary_aliasizer)
else:
- dest_selectable = self.mapper.mapped_table.alias()
-
- primaryjoin = self.primaryjoin
- if source_selectable:
- if self.direction in (ONETOMANY, MANYTOMANY):
- primaryjoin = ClauseAdapter(source_selectable, exclude=self.foreign_keys, equivalents=self.parent._equivalent_columns).traverse(primaryjoin)
+ secondary_aliasizer = primary_aliasizer
+
+ if source_selectable:
+ primary_aliasizer = ClauseAdapter(secondary).chain(ClauseAdapter(source_selectable, equivalents=self.parent._equivalent_columns))
+
+ secondaryjoin = secondary_aliasizer.traverse(secondaryjoin)
else:
- primaryjoin = ClauseAdapter(source_selectable, include=self.foreign_keys, equivalents=self.parent._equivalent_columns).traverse(primaryjoin)
+ if dest_selectable:
+ primary_aliasizer = ClauseAdapter(dest_selectable, exclude=self.local_side, equivalents=self.mapper._equivalent_columns)
+ if source_selectable:
+ primary_aliasizer.chain(ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns))
+ elif source_selectable:
+ primary_aliasizer = ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns)
+
+ secondary_aliasizer = None
- secondaryjoin = self.secondaryjoin
- target_adapter = None
- if dest_selectable:
- if self.direction == ONETOMANY:
- target_adapter = ClauseAdapter(dest_selectable, include=self.foreign_keys, equivalents=self.mapper._equivalent_columns)
- elif self.direction == MANYTOMANY:
- target_adapter = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns)
- else:
- target_adapter = ClauseAdapter(dest_selectable, exclude=self.foreign_keys, equivalents=self.mapper._equivalent_columns)
- if secondaryjoin:
- secondaryjoin = target_adapter.traverse(secondaryjoin)
- else:
- primaryjoin = target_adapter.traverse(primaryjoin)
+ primaryjoin = primary_aliasizer.traverse(primaryjoin)
+ target_adapter = secondary_aliasizer or primary_aliasizer
target_adapter.include = target_adapter.exclude = None
-
- return primaryjoin, secondaryjoin, source_selectable or self.parent.local_table, dest_selectable or self.mapper.local_table, target_adapter
+ else:
+ target_adapter = None
+
+ self.__join_cache[key] = ret = (primaryjoin, secondaryjoin, (source_selectable or self.parent.local_table), (dest_selectable or self.mapper.local_table), secondary, target_adapter)
+ return ret
def _get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True):
"""deprecated. use primary_join_against(), secondary_join_against(), full_join_against()"""
- pj, sj, source, dest, adapter = self._create_joins(source_polymorphic=polymorphic_parent)
+ pj, sj, source, dest, secondarytable, adapter = self._create_joins(source_polymorphic=polymorphic_parent)
if primary and secondary:
return pj & sj
if not self.viewonly:
self._dependency_processor.register_dependencies(uowcommit)
-PropertyLoader.logger = logging.class_logger(PropertyLoader)
+PropertyLoader.logger = log.class_logger(PropertyLoader)
class BackRef(object):
"""Attached to a PropertyLoader to indicate a complementary reverse relationship.
self.key = key
self.kwargs = kwargs
self.prop = _prop
-
+ self.extension = attributes.GenericBackrefExtension(self.key)
+
def compile(self, prop):
if self.prop:
return
relation = PropertyLoader(parent, prop.secondary, pj, sj,
backref=BackRef(prop.key, _prop=prop),
- is_backref=True,
+ _is_backref=True,
**self.kwargs)
mapper._compile_property(self.key, relation);
mapper._get_property(self.key)._reverse_property = prop
else:
- raise exceptions.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper))
-
- def get_extension(self):
- """Return an attribute extension to use with this backreference."""
-
- return attributes.GenericBackrefExtension(self.key)
+ raise sa_exc.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper))
mapper.ColumnProperty = ColumnProperty
mapper.SynonymProperty = SynonymProperty
that it returns ORM-mapped objects and interacts with an ORM session, whereas
the ``Select`` construct interacts directly with the database to return
iterable result sets.
+
"""
from itertools import chain
-from sqlalchemy import sql, util, exceptions, logging
+
+from sqlalchemy import sql, util, log
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import expression, visitors, operators
-from sqlalchemy.orm import mapper, object_mapper
+from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper
+from sqlalchemy.orm.util import _state_mapper, _is_mapped_class, \
+ _is_aliased_class, _entity_descriptor, _entity_info, _class_to_mapper, \
+ _orm_columns, AliasedClass, _orm_selectable, join as orm_join, ORMAdapter
-from sqlalchemy.orm.util import _state_mapper, _class_to_mapper, _is_mapped_class, _is_aliased_class
-from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm import interfaces
-from sqlalchemy.orm import attributes
-from sqlalchemy.orm.util import AliasedClass
+__all__ = ['Query', 'QueryContext', 'aliased']
-aliased = AliasedClass
-__all__ = ['Query', 'QueryContext', 'aliased']
+aliased = AliasedClass
+def _generative(*assertions):
+ """mark a method as generative."""
+
+ def decorate(fn):
+ argspec = util.format_argspec_plus(fn)
+ run_assertions = assertions
+ code = "\n".join([
+ "def %s%s:",
+ " %r",
+ " self = self._clone()",
+ " for a in run_assertions:",
+ " a(self, %r)",
+ " fn%s",
+ " return self"
+ ]) % (fn.__name__, argspec['args'], fn.__doc__, fn.__name__, argspec['apply_pos'])
+ env = locals().copy()
+ exec code in env
+ return env[fn.__name__]
+ return decorate
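The exec-based construction above exists to preserve each wrapped method's exact argument signature for introspection. A plainer sketch of the same clone-mutate-return pattern, with signature preservation omitted (illustrative, not the library's implementation):

    def generative(*assertions):
        """Decorate a method to mutate and return a copy of self."""
        def decorate(fn):
            def decorated(self, *args, **kwargs):
                self = self._clone()
                for a in assertions:
                    a(self, fn.__name__)
                fn(self, *args, **kwargs)
                return self
            decorated.__name__ = fn.__name__
            decorated.__doc__ = fn.__doc__
            return decorated
        return decorate

    class Box(object):
        def __init__(self):
            self.items = []
        def _clone(self):
            c = Box.__new__(Box)
            c.__dict__ = self.__dict__.copy()
            c.items = list(self.items)   # don't share mutable state
            return c
        def add(self, item):
            self.items.append(item)
        add = generative()(add)

    b = Box()
    b2 = b.add(1).add(2)
    print b.items, b2.items    # [] [1, 2]
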
class Query(object):
"""Encapsulates the object-fetching operations provided by Mappers."""
- def __init__(self, class_or_mapper, session=None, entity_name=None):
- self._session = session
-
+ def __init__(self, entities, session=None, entity_name=None):
+ self.session = session
+
self._with_options = []
self._lockmode = None
-
- self._entities = []
self._order_by = False
self._group_by = False
self._distinct = False
self._params = {}
self._yield_per = None
self._criterion = None
+ self._correlate = util.Set()
+ self._joinpoint = None
+ self._with_labels = False
self.__joinable_tables = None
self._having = None
- self._column_aggregate = None
self._populate_existing = False
self._version_check = False
self._autoflush = True
-
self._attributes = {}
self._current_path = ()
self._only_load_props = None
self._refresh_instance = None
-
- self.__init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name))
-
- def __init_mapper(self, mapper):
- """populate all instance variables derived from this Query's mapper."""
-
- self.mapper = mapper
- self.table = self._from_obj = self.mapper.mapped_table
- self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
- self._extension = self.mapper.extension
- self._aliases_head = self._aliases_tail = None
- self._alias_ids = {}
- self._joinpoint = self.mapper
- self._entities.append(_PrimaryMapperEntity(self.mapper))
- if self.mapper.with_polymorphic:
- self.__set_with_polymorphic(*self.mapper.with_polymorphic)
- else:
- self._with_polymorphic = []
-
- def __generate_alias_ids(self):
- self._alias_ids = dict([
- (k, list(v)) for k, v in self._alias_ids.iteritems()
- ])
+ self._from_obj = None
+ self._entities = []
+ self._polymorphic_adapters = {}
+ self._filter_aliases = None
+ self._from_obj_alias = None
+ self.__currenttables = util.Set()
+
+ for ent in util.to_list(entities):
+ _QueryEntity(self, ent, entity_name=entity_name)
+
+ self.__setup_aliasizers(self._entities)
+
+ def __setup_aliasizers(self, entities):
+ d = {}
+ for ent in entities:
+ for entity in ent.entities:
+ if entity not in d:
+ mapper, selectable, is_aliased_class = _entity_info(entity, ent.entity_name)
+ if not is_aliased_class and mapper.with_polymorphic:
+ with_polymorphic = mapper._with_polymorphic_mappers
+ self.__mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter(selectable, mapper._equivalent_columns))
+ adapter = None
+ elif is_aliased_class:
+ adapter = sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)
+ with_polymorphic = None
+ else:
+ with_polymorphic = adapter = None
+                    d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic)
+                ent.setup_entity(entity, *d[entity])
+
-    def __no_criterion(self, meth):
-        return self.__conditional_clone(meth, [self.__no_criterion_condition])
-
-    def __no_statement(self, meth):
-        return self.__conditional_clone(meth, [self.__no_statement_condition])
-
-    def __reset_all(self, mapper, meth):
-        q = self.__conditional_clone(meth, [self.__no_criterion_condition])
-        q.__init_mapper(mapper, mapper)
-        return q
+ def __mapper_loads_polymorphically_with(self, mapper, adapter):
+ for m2 in mapper._with_polymorphic_mappers:
+ for m in m2.iterate_to_root():
+ self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
def __set_select_from(self, from_obj):
if isinstance(from_obj, expression._SelectBaseMixin):
from_obj = from_obj.alias()
self._from_obj = from_obj
- self._alias_ids = {}
+ equivs = self.__all_equivs()
+
+ if isinstance(from_obj, expression.Alias):
+ # dont alias a regular join (since its not an alias itself)
+ self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj, equivs)
+
+ def _get_polymorphic_adapter(self, entity, selectable):
+ self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns))
+
+ def _reset_polymorphic_adapter(self, mapper):
+ for m2 in mapper._with_polymorphic_mappers:
+ for m in m2.iterate_to_root():
+ self._polymorphic_adapters.pop(m.mapped_table, None)
+ self._polymorphic_adapters.pop(m.local_table, None)
+
+ def __reset_joinpoint(self):
+ self._joinpoint = None
+ self._filter_aliases = None
+
+ def __adapt_polymorphic_element(self, element):
+ if isinstance(element, expression.FromClause):
+ search = element
+ elif hasattr(element, 'table'):
+ search = element.table
+ else:
+ search = None
+
+ if search:
+ alias = self._polymorphic_adapters.get(search, None)
+ if alias:
+ return alias.adapt_clause(element)
+
+ def __replace_element(self, adapters):
+ def replace(elem):
+ if '_halt_adapt' in elem._annotations:
+ return elem
+
+ for adapter in adapters:
+ e = adapter(elem)
+ if e:
+ return e
+ return replace
+
+ def __replace_orm_element(self, adapters):
+ def replace(elem):
+ if '_halt_adapt' in elem._annotations:
+ return elem
+
+ if "_orm_adapt" in elem._annotations or "parententity" in elem._annotations:
+ for adapter in adapters:
+ e = adapter(elem)
+ if e:
+ return e
+ return replace
+
+ def _adapt_all_clauses(self):
+ self._disable_orm_filtering = True
+ _adapt_all_clauses = _generative()(_adapt_all_clauses)
+
+ def _adapt_clause(self, clause, as_filter, orm_only):
+ adapters = []
+ if as_filter and self._filter_aliases:
+ adapters.append(self._filter_aliases.replace)
+
+ if self._polymorphic_adapters:
+ adapters.append(self.__adapt_polymorphic_element)
+
+ if self._from_obj_alias:
+ adapters.append(self._from_obj_alias.replace)
+
+ if not adapters:
+ return clause
+
+ if getattr(self, '_disable_orm_filtering', not orm_only):
+ return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_element(adapters))
+ else:
+ return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_orm_element(adapters))
- if self.table not in self._get_joinable_tables():
- self._aliases_head = self._aliases_tail = mapperutil.AliasedClauses(self._from_obj, equivalents=self.mapper._equivalent_columns)
- self._alias_ids.setdefault(self.table, []).append(self._aliases_head)
+ def _entity_zero(self):
+ return self._entities[0]
+
+ def _mapper_zero(self):
+ return self._entity_zero().entity_zero
+
+ def _extension_zero(self):
+ ent = self._entity_zero()
+ return getattr(ent, 'extension', ent.mapper.extension)
+
+ def _mapper_entities(self):
+ for ent in self._entities:
+ if hasattr(ent, 'primary_entity'):
+ yield ent
+ _mapper_entities = property(_mapper_entities)
+
+ def _joinpoint_zero(self):
+ return self._joinpoint or self._entity_zero().entity_zero
+
+ def _mapper_zero_or_none(self):
+ if not getattr(self._entities[0], 'primary_entity', False):
+ return None
+ return self._entities[0].mapper
+
+ def _only_mapper_zero(self):
+ if len(self._entities) > 1:
+ raise sa_exc.InvalidRequestError("This operation requires a Query against a single mapper.")
+ return self._mapper_zero()
+
+ def _only_entity_zero(self):
+ if len(self._entities) > 1:
+ raise sa_exc.InvalidRequestError("This operation requires a Query against a single mapper.")
+ return self._entity_zero()
+
+ def _generate_mapper_zero(self):
+ if not getattr(self._entities[0], 'primary_entity', False):
+ raise sa_exc.InvalidRequestError("No primary mapper set up for this Query.")
+ entity = self._entities[0]._clone()
+ self._entities = [entity] + self._entities[1:]
+ return entity
+
+ def __mapper_zero_from_obj(self):
+ if self._from_obj:
+ return self._from_obj
else:
- self._aliases_head = self._aliases_tail = None
+ return self._entity_zero().selectable
- def __set_with_polymorphic(self, cls_or_mappers, selectable=None):
- mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
- self._with_polymorphic = mappers
- self.__set_select_from(from_obj)
+ def __all_equivs(self):
+ equivs = {}
+ for ent in self._mapper_entities:
+ equivs.update(ent.mapper._equivalent_columns)
+ return equivs
- def __no_criterion_condition(self, q, meth):
- if q._criterion or q._statement:
+ def __no_criterion_condition(self, meth):
+ if self._criterion or self._statement or self._from_obj:
util.warn(
("Query.%s() being called on a Query with existing criterion; "
-                 "criterion is being ignored.") % meth)
+                 "criterion is being ignored. This usage is deprecated.") % meth)
-
-        q._joinpoint = self.mapper
-        q._statement = q._criterion = None
-        q._order_by = q._group_by = q._distinct = False
-        q._aliases_tail = q._aliases_head
-        q.table = q._from_obj = q.mapper.mapped_table
-        if q.mapper.with_polymorphic:
-            q.__set_with_polymorphic(*q.mapper.with_polymorphic)
+        self._statement = self._criterion = self._from_obj = None
+        self._order_by = self._group_by = self._distinct = False
+        self.__joined_tables = {}
-
-    def __no_entities(self, meth):
-        q = self.__no_statement(meth)
-        if len(q._entities) > 1 and not isinstance(q._entities[0], _PrimaryMapperEntity):
-            raise exceptions.InvalidRequestError(
-                ("Query.%s() being called on a Query with existing "
-                "additional entities or columns - can't replace columns") % meth)
-        q._entities = []
-        return q
-
-    def __no_statement_condition(self, q, meth):
-        if q._statement:
-            raise exceptions.InvalidRequestError(
+
+ def __no_from_condition(self, meth):
+ if self._from_obj:
+ raise sa_exc.InvalidRequestError("Query.%s() being called on a Query which already has a FROM clause established. This usage is deprecated." % meth)
+
+ def __no_statement_condition(self, meth):
+ if self._statement:
+ raise sa_exc.InvalidRequestError(
("Query.%s() being called on a Query with an existing full "
"statement - can't apply criterion.") % meth)
- def __conditional_clone(self, methname=None, conditions=None):
- q = self._clone()
- if conditions:
- for condition in conditions:
- condition(q, methname)
- return q
+ def __no_limit_offset(self, meth):
+ if self._limit or self._offset:
+ util.warn("Query.%s() being called on a Query which already has LIMIT or OFFSET applied. "
+ "This usage is deprecated. Apply filtering and joins before LIMIT or OFFSET are applied, "
+                "or to filter/join to the row-limited results of the query, call from_self() first. "
+ "In release 0.5, from_self() will be called automatically in this scenario."
+ )
+
+ def __no_criterion(self):
+ """generate a Query with no criterion, warn if criterion was present"""
+ __no_criterion = _generative(__no_criterion_condition)(__no_criterion)
def __get_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None):
if populate_existing:
q.__dict__ = self.__dict__.copy()
return q
- def session(self):
- if self._session is None:
- return self.mapper.get_session()
- else:
- return self._session
- session = property(session)
-
def statement(self):
"""return the full SELECT statement represented by this Query."""
- return self._compile_context().statement
+ return self._compile_context(labels=self._with_labels).statement
statement = property(statement)
+ def with_labels(self):
+ """Apply column labels to the return value of Query.statement.
+
+ Indicates that this Query's `statement` accessor should return a SELECT statement
+ that applies labels to all columns in the form <tablename>_<columnname>; this
+ is commonly used to disambiguate columns from multiple tables which have the
+ same name.
+
+ When the `Query` actually issues SQL to load rows, it always uses
+ column labeling.
+
+ """
+ self._with_labels = True
+ with_labels = _generative()(with_labels)
+
+
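A quick demonstration of the labeled statement accessor, assuming a minimal classic mapping (table and class names are arbitrary):

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String
    from sqlalchemy.orm import mapper, sessionmaker

    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)))

    class User(object):
        pass
    mapper(User, users)

    session = sessionmaker(bind=create_engine('sqlite://'))()
    print session.query(User).with_labels().statement
    # roughly: SELECT users.id AS users_id, users.name AS users_name FROM users
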
def whereclause(self):
"""return the WHERE criterion for this Query."""
return self._criterion
def _with_current_path(self, path):
"""indicate that this query applies to objects loaded within a certain path.
-
- Used by deferred loaders (see strategies.py) which transfer query
+
+ Used by deferred loaders (see strategies.py) which transfer query
options from an originating query to a newly generated query intended
for the deferred load.
-
+
"""
- q = self._clone()
- q._current_path = path
- return q
+ self._current_path = path
+ _with_current_path = _generative()(_with_current_path)
def with_polymorphic(self, cls_or_mappers, selectable=None):
"""Load columns for descendant mappers of this Query's mapper.
-
+
Using this method will ensure that each descendant mapper's
- tables are included in the FROM clause, and will allow filter()
- criterion to be used against those tables. The resulting
+ tables are included in the FROM clause, and will allow filter()
+ criterion to be used against those tables. The resulting
instances will also have those columns already loaded so that
no "post fetch" of those columns will be required.
-
+
``cls_or_mappers`` is a single class or mapper, or list of class/mappers,
which inherit from this Query's mapper. Alternatively, it
- may also be the string ``'*'``, in which case all descending
+ may also be the string ``'*'``, in which case all descending
mappers will be added to the FROM clause.
-
- ``selectable`` is a table or select() statement that will
+
+ ``selectable`` is a table or select() statement that will
be used in place of the generated FROM clause. This argument
- is required if any of the desired mappers use concrete table
- inheritance, since SQLAlchemy currently cannot generate UNIONs
- among tables automatically. If used, the ``selectable``
- argument must represent the full set of tables and columns mapped
+ is required if any of the desired mappers use concrete table
+ inheritance, since SQLAlchemy currently cannot generate UNIONs
+ among tables automatically. If used, the ``selectable``
+ argument must represent the full set of tables and columns mapped
by every desired mapper. Otherwise, the unaccounted mapped columns
- will result in their table being appended directly to the FROM
+ will result in their table being appended directly to the FROM
clause which will usually lead to incorrect results.
"""
- q = self.__no_criterion('with_polymorphic')
-
- q.__set_with_polymorphic(cls_or_mappers, selectable=selectable)
+ entity = self._generate_mapper_zero()
+ entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable)
+ with_polymorphic = _generative(__no_from_condition, __no_criterion_condition)(with_polymorphic)
- return q
-
-
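Typical use is against a joined-table inheritance hierarchy; Employee, Engineer and Manager below are assumed mapped classes, not defined here:

    # load engineer and manager columns in the same SELECT, and permit
    # filter() criterion against the subclass tables
    q = session.query(Employee).with_polymorphic([Engineer, Manager])
    q = q.filter(Engineer.engineer_info == 'knows python')

    # '*' pulls every descendant mapper's table into the FROM clause
    q = session.query(Employee).with_polymorphic('*')
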
def yield_per(self, count):
"""Yield only ``count`` rows at a time.
eagerly loaded collections (i.e. any lazy=False) since those
collections will be cleared for a new load when encountered in a
subsequent result batch.
- """
- q = self._clone()
- q._yield_per = count
- return q
+ """
+ self._yield_per = count
+ yield_per = _generative()(yield_per)
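Usage is one call in the generative chain; continuing the User mapping sketched earlier:

    # fetch rows in batches of 100 instead of buffering the full result
    for user in session.query(User).yield_per(100):
        print user.name
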
def get(self, ident, **kwargs):
"""Return an instance of the object based on the given identifier, or None if not found.
The `ident` argument is a scalar or tuple of primary key column values
in the order of the table def's primary key columns.
+
"""
- ret = self._extension.get(self, ident, **kwargs)
+ ret = self._extension_zero().get(self, ident, **kwargs)
if ret is not mapper.EXT_CONTINUE:
return ret
# convert composite types to individual args
- # TODO: account for the order of columns in the
- # ColumnProperty it corresponds to
if hasattr(ident, '__composite_values__'):
ident = ident.__composite_values__()
- key = self.mapper.identity_key_from_primary_key(ident)
+ key = self._only_mapper_zero().identity_key_from_primary_key(ident)
return self._get(key, ident, **kwargs)
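For example, with the User mapping from earlier and a hypothetical Assoc class mapped to a two-column primary key:

    u = session.query(User).get(5)         # single-column primary key
    a = session.query(Assoc).get((5, 7))   # composite key, in table-def order
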
def load(self, ident, raiseerr=True, **kwargs):
pending changes** to the object already existing in the Session. The
`ident` argument is a scalar or tuple of primary key column values in
the order of the table def's primary key columns.
- """
- ret = self._extension.load(self, ident, **kwargs)
+ """
+ ret = self._extension_zero().load(self, ident, **kwargs)
if ret is not mapper.EXT_CONTINUE:
return ret
- key = self.mapper.identity_key_from_primary_key(ident)
+
+ # convert composite types to individual args
+ if hasattr(ident, '__composite_values__'):
+ ident = ident.__composite_values__()
+
+ key = self._only_mapper_zero().identity_key_from_primary_key(ident)
instance = self.populate_existing()._get(key, ident, **kwargs)
if instance is None and raiseerr:
- raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
+ raise sa_exc.InvalidRequestError("No instance found for identity %s" % repr(ident))
return instance
def query_from_parent(cls, instance, property, **kwargs):
\**kwargs
all extra keyword arguments are propagated to the constructor of
Query.
- """
+ deprecated. use sqlalchemy.orm.with_parent in conjunction with
+ filter().
+
+ """
mapper = object_mapper(instance)
prop = mapper.get_property(property, resolve_synonyms=True)
target = prop.mapper
criterion = prop.compare(operators.eq, instance, value_is_parent=True)
return Query(target, **kwargs).filter(criterion)
- query_from_parent = classmethod(query_from_parent)
+ query_from_parent = classmethod(util.deprecated(None, False)(query_from_parent))
+
+ def correlate(self, *args):
+ self._correlate = self._correlate.union([_orm_selectable(s) for s in args])
+ correlate = _generative()(correlate)
def autoflush(self, setting):
"""Return a Query with a specific 'autoflush' setting.
Note that a Session with autoflush=False will
- not autoflush, even if this flag is set to True at the
+ not autoflush, even if this flag is set to True at the
Query level. Therefore this flag is usually used only
to disable autoflush for a specific Query.
-
+
"""
- q = self._clone()
- q._autoflush = setting
- return q
+ self._autoflush = setting
+ autoflush = _generative()(autoflush)
def populate_existing(self):
"""Return a Query that will refresh all instances loaded.
An alternative to populate_existing() is to expire the Session
fully using session.expire_all().
-
+
"""
- q = self._clone()
- q._populate_existing = True
- return q
+ self._populate_existing = True
+ populate_existing = _generative()(populate_existing)
def with_parent(self, instance, property=None):
"""add a join criterion corresponding to a relationship to the given parent instance.
mapper = object_mapper(instance)
if property is None:
for prop in mapper.iterate_properties:
- if isinstance(prop, properties.PropertyLoader) and prop.mapper is self.mapper:
+ if isinstance(prop, properties.PropertyLoader) and prop.mapper is self._mapper_zero():
break
else:
- raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__))
+ raise sa_exc.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self._mapper_zero().class_.__name__, instance.__class__.__name__))
else:
prop = mapper.get_property(property, resolve_synonyms=True)
return self.filter(prop.compare(operators.eq, instance, value_is_parent=True))
- def add_entity(self, entity, alias=None, id=None):
- """add a mapped entity to the list of result columns to be returned.
-
- This will have the effect of all result-returning methods returning a tuple
- of results, the first element being an instance of the primary class for this
- Query, and subsequent elements matching columns or entities which were
- specified via add_column or add_entity.
-
- When adding entities to the result, its generally desirable to add
- limiting criterion to the query which can associate the primary entity
- of this Query along with the additional entities. The Query selects
- from all tables with no joining criterion by default.
+ def add_entity(self, entity, alias=None):
+ """add a mapped entity to the list of result columns to be returned."""
- entity
- a class or mapper which will be added to the results.
+ if alias:
+ entity = aliased(entity, alias)
- alias
- a sqlalchemy.sql.Alias object which will be used to select rows. this
- will match the usage of the given Alias in filter(), order_by(), etc. expressions
+ self._entities = list(self._entities)
+ m = _MapperEntity(self, entity)
+ self.__setup_aliasizers([m])
+    add_entity = _generative()(add_entity)
- id
- a string ID matching that given to query.join() or query.outerjoin(); rows will be
- selected from the aliased join created via those methods.
+ def from_self(self, *entities):
+ """return a Query that selects from this Query's SELECT statement.
+ \*entities - optional list of entities which will replace
+ those being selected.
"""
- q = self._clone()
-
- if not alias and _is_aliased_class(entity):
- alias = entity.alias
- if isinstance(entity, type):
- entity = mapper.class_mapper(entity)
+ fromclause = self.compile().correlate(None)
+ self._statement = self._criterion = None
+ self._order_by = self._group_by = self._distinct = False
+ self._limit = self._offset = None
+ self.__set_select_from(fromclause)
+ if entities:
+ self._entities = []
+ for ent in entities:
+ _QueryEntity(self, ent)
+ self.__setup_aliasizers(self._entities)
- if alias is not None:
- alias = mapperutil.AliasedClauses(alias)
+ from_self = _generative()(from_self)
+ _from_self = from_self
- q._entities = q._entities + [_MapperEntity(mapper=entity, alias=alias, id=id)]
- return q
-
- def _from_self(self):
- """return a Query that selects from this Query's SELECT statement.
-
- The API for this method hasn't been decided yet and is subject to change.
-
- """
- q = self._clone()
- q._eager_loaders = util.Set()
- fromclause = q.compile().correlate(None)
- return Query(self.mapper, self.session).select_from(fromclause)
-
def values(self, *columns):
"""Return an iterator yielding result tuples corresponding to the given list of columns"""
-
- q = self.__no_entities('_values')
- q._only_load_props = q._eager_loaders = util.Set()
- q._no_filters = True
+
+ if not columns:
+ return iter(())
+ q = self._clone()
+ q._entities = []
for column in columns:
- q._entities.append(self._add_column(column, None, False))
+ _ColumnEntity(q, column)
+ q.__setup_aliasizers(q._entities)
if not q._yield_per:
- q = q.yield_per(10)
+ q._yield_per = 10
return iter(q)
_values = values
-
- def add_column(self, column, id=None):
- """Add a SQL ColumnElement to the list of result columns to be returned.
- This will have the effect of all result-returning methods returning a
- tuple of results, the first element being an instance of the primary
- class for this Query, and subsequent elements matching columns or
- entities which were specified via add_column or add_entity.
+ def add_column(self, column):
+ """Add a SQL ColumnElement to the list of result columns to be returned."""
- When adding columns to the result, its generally desirable to add
- limiting criterion to the query which can associate the primary entity
- of this Query along with the additional columns, if the column is
- based on a table or selectable that is not the primary mapped
- selectable. The Query selects from all tables with no joining
- criterion by default.
+ self._entities = list(self._entities)
+ c = _ColumnEntity(self, column)
+ self.__setup_aliasizers([c])
+    add_column = _generative()(add_column)
- column
- a string column name or sql.ColumnElement to be added to the results.
-
- """
- q = self._clone()
- q._entities = q._entities + [self._add_column(column, id, True)]
- return q
-
- def _add_column(self, column, id, looks_for_aliases):
- if isinstance(column, interfaces.PropComparator):
- column = column.clause_element()
-
- elif not isinstance(column, (sql.ColumnElement, basestring)):
- raise exceptions.InvalidRequestError("Invalid column expression '%r'" % column)
-
- return _ColumnEntity(column, id)
-
def options(self, *args):
"""Return a new Query object, applying the given list of
MapperOptions.
"""
- return self._options(False, *args)
+ return self.__options(False, *args)
def _conditional_options(self, *args):
- return self._options(True, *args)
+ return self.__options(True, *args)
- def _options(self, conditional, *args):
- q = self._clone()
+ def __options(self, conditional, *args):
# most MapperOptions write to the '_attributes' dictionary,
# so copy that as well
- q._attributes = q._attributes.copy()
+ self._attributes = self._attributes.copy()
opts = [o for o in util.flatten_iterator(args)]
- q._with_options = q._with_options + opts
+ self._with_options = self._with_options + opts
if conditional:
for opt in opts:
- opt.process_query_conditionally(q)
+ opt.process_query_conditionally(self)
else:
for opt in opts:
- opt.process_query(q)
- return q
+ opt.process_query(self)
+    __options = _generative()(__options)
def with_lockmode(self, mode):
"""Return a new Query object with the specified locking mode."""
-
- q = self._clone()
- q._lockmode = mode
- return q
+
+ self._lockmode = mode
+    with_lockmode = _generative()(with_lockmode)
def params(self, *args, **kwargs):
"""add values for bind parameters which may have been specified in filter().
\**kwargs cannot be used.
"""
- q = self._clone()
if len(args) == 1:
kwargs.update(args[0])
elif len(args) > 0:
- raise exceptions.ArgumentError("params() takes zero or one positional argument, which is a dictionary.")
- q._params = q._params.copy()
- q._params.update(kwargs)
- return q
+ raise sa_exc.ArgumentError("params() takes zero or one positional argument, which is a dictionary.")
+ self._params = self._params.copy()
+ self._params.update(kwargs)
+    params = _generative()(params)
def filter(self, criterion):
"""apply the given filtering criterion to the query and return the newly resulting ``Query``
criterion = sql.text(criterion)
if criterion is not None and not isinstance(criterion, sql.ClauseElement):
- raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
+ raise sa_exc.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
- if self._aliases_tail:
- criterion = self._aliases_tail.adapt_clause(criterion)
+ criterion = self._adapt_clause(criterion, True, True)
- q = self.__no_statement("filter")
- if q._criterion is not None:
- q._criterion = q._criterion & criterion
+ if self._criterion is not None:
+ self._criterion = self._criterion & criterion
else:
- q._criterion = criterion
- return q
+ self._criterion = criterion
+    filter = _generative(__no_statement_condition, __no_limit_offset)(filter)
def filter_by(self, **kwargs):
"""apply the given filtering criterion to the query and return the newly resulting ``Query``."""
- clauses = [self._joinpoint.get_property(key, resolve_synonyms=True).compare(operators.eq, value)
+ clauses = [_entity_descriptor(self._joinpoint_zero(), key)[0] == value
for key, value in kwargs.iteritems()]
return self.filter(sql.and_(*clauses))
def order_by(self, *criterion):
"""apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
- q = self.__no_statement("order_by")
+ criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion]
- if self._aliases_tail:
- criterion = tuple(self._aliases_tail.adapt_list(
- [expression._literal_as_text(o) for o in criterion]
- ))
-
- if q._order_by is False:
- q._order_by = criterion
+ if self._order_by is False:
+ self._order_by = criterion
else:
- q._order_by = q._order_by + criterion
- return q
+ self._order_by = self._order_by + criterion
order_by = util.array_as_starargs_decorator(order_by)
-
+ order_by = _generative(__no_statement_condition)(order_by)
+
def group_by(self, *criterion):
"""apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``"""
- q = self.__no_statement("group_by")
- if q._group_by is False:
- q._group_by = criterion
+ criterion = list(chain(*[_orm_columns(c) for c in criterion]))
+
+ if self._group_by is False:
+ self._group_by = criterion
else:
- q._group_by = q._group_by + criterion
- return q
+ self._group_by = self._group_by + criterion
group_by = util.array_as_starargs_decorator(group_by)
-
+ group_by = _generative(__no_statement_condition)(group_by)
+
def having(self, criterion):
"""apply a HAVING criterion to the query and return the newly resulting ``Query``."""
criterion = sql.text(criterion)
if criterion is not None and not isinstance(criterion, sql.ClauseElement):
- raise exceptions.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string")
+ raise sa_exc.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string")
- if self._aliases_tail:
- criterion = self._aliases_tail.adapt_clause(criterion)
+ criterion = self._adapt_clause(criterion, True, True)
- q = self.__no_statement("having")
- if q._having is not None:
- q._having = q._having & criterion
+ if self._having is not None:
+ self._having = self._having & criterion
else:
- q._having = criterion
- return q
+ self._having = criterion
+    having = _generative(__no_statement_condition)(having)
- def join(self, prop, id=None, aliased=False, from_joinpoint=False):
+ def join(self, *props, **kwargs):
"""Create a join against this ``Query`` object's criterion
- and apply generatively, retunring the newly resulting ``Query``.
-
- 'prop' may be one of:
- * a string property name, i.e. "rooms"
- * a class-mapped attribute, i.e. Houses.rooms
- * a 2-tuple containing one of the above, combined with a selectable
- which derives from the properties' mapped table
- * a list (not a tuple) containing a combination of any of the above.
+ and apply generatively, returning the newly resulting ``Query``.
+ each element in \*props may be:
+
+ * a string property name, i.e. "rooms". This will join along
+ the relation of the same name from this Query's "primary"
+ mapper, if one is present.
+
+ * a class-mapped attribute, i.e. Houses.rooms. This will create a
+ join from "Houses" table to that of the "rooms" relation.
+
+ * a 2-tuple containing a target class or selectable, and
+ an "ON" clause. The ON clause can be the property name/
+ attribute like above, or a SQL expression.
+
+
e.g.::
+ # join along string attribute names
session.query(Company).join('employees')
- session.query(Company).join(['employees', 'tasks'])
- session.query(Houses).join([Colonials.rooms, Room.closets])
- session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
+ session.query(Company).join('employees', 'tasks')
+
+ # join the Person entity to an alias of itself,
+ # along the "friends" relation
+ PAlias = aliased(Person)
+            session.query(Person).join((PAlias, Person.friends))
+
+ # join from Houses to the "rooms" attribute on the
+ # "Colonials" subclass of Houses, then join to the
+ # "closets" relation on Room
+ session.query(Houses).join(Colonials.rooms, Room.closets)
+
+ # join from Company entities to the "employees" collection,
+ # using "people JOIN engineers" as the target. Then join
+ # to the "computers" collection on the Engineer entity.
+ session.query(Company).join((people.join(engineers), 'employees'), Engineer.computers)
+
+ # join from Articles to Keywords, using the "keywords" attribute.
+ # assume this is a many-to-many relation.
+ session.query(Article).join(Article.keywords)
+
+ # same thing, but spelled out entirely explicitly
+ # including the association table.
+ session.query(Article).join(
+                (article_keywords, Article.id==article_keywords.c.article_id),
+ (Keyword, Keyword.id==article_keywords.c.keyword_id)
+ )
+
+ \**kwargs include:
+
+ aliased - when joining, create anonymous aliases of each table. This is
+ used for self-referential joins or multiple joins to the same table.
+ Consider usage of the aliased(SomeClass) construct as a more explicit
+ approach to this.
+
+ from_joinpoint - when joins are specified using string property names,
+ locate the property from the mapper found in the most recent previous
+ join() call, instead of from the root entity.
"""
- return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
+ aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
+ if kwargs:
+ raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys()))
+ return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint)
+ join = util.array_as_starargs_decorator(join)
- def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False):
+ def outerjoin(self, *props, **kwargs):
"""Create a left outer join against this ``Query`` object's criterion
        and apply generatively, returning the newly resulting ``Query``.
+
+ Usage is the same as the ``join()`` method.
- 'prop' may be one of:
- * a string property name, i.e. "rooms"
- * a class-mapped attribute, i.e. Houses.rooms
- * a 2-tuple containing one of the above, combined with a selectable
- which derives from the properties' mapped table
- * a list (not a tuple) containing a combination of any of the above.
+ """
+ aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
+ if kwargs:
+ raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys()))
+ return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint)
+ outerjoin = util.array_as_starargs_decorator(outerjoin)
- e.g.::
+ def __join(self, keys, outerjoin, create_aliases, from_joinpoint):
+ self.__currenttables = util.Set(self.__currenttables)
+ self._polymorphic_adapters = self._polymorphic_adapters.copy()
- session.query(Company).outerjoin('employees')
- session.query(Company).outerjoin(['employees', 'tasks'])
- session.query(Houses).outerjoin([Colonials.rooms, Room.closets])
- session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
+ if not from_joinpoint:
+ self.__reset_joinpoint()
- """
- return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint)
-
- def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
- (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
- # TODO: improve the generative check here to look for primary mapped entity, etc.
- q = self.__no_statement("join")
- q._from_obj = clause
- q._joinpoint = mapper
- q._aliases = aliases
- q.__generate_alias_ids()
-
- if aliases:
- q._aliases_tail = aliases
-
- a = aliases
- while a is not None:
- if isinstance(a, mapperutil.PropertyAliasedClauses):
- q._alias_ids.setdefault(a.mapper, []).append(a)
- q._alias_ids.setdefault(a.table, []).append(a)
- a = a.parentclauses
+ clause = self._from_obj
+ right_entity = None
+
+ for arg1 in util.to_list(keys):
+ prop = None
+ aliased_entity = False
+ alias_criterion = False
+ left_entity = right_entity
+ right_entity = right_mapper = None
+
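+            # each argument is either an "ON" specifier by itself, or a
+            # (target, onclause) 2-tuple; sort out which element is which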
+ if isinstance(arg1, tuple):
+ arg1, arg2 = arg1
else:
- break
+ arg2 = None
+
+ if isinstance(arg2, (interfaces.PropComparator, basestring)):
+ onclause = arg2
+ right_entity = arg1
+ elif isinstance(arg1, (interfaces.PropComparator, basestring)):
+ onclause = arg1
+ right_entity = arg2
+ else:
+ onclause = arg2
+ right_entity = arg1
- if id:
- q._alias_ids[id] = [aliases]
- return q
+ if isinstance(onclause, interfaces.PropComparator):
+ of_type = getattr(onclause, '_of_type', None)
+ prop = onclause.property
+ descriptor = onclause
+
+ if not left_entity:
+ left_entity = onclause.parententity
+
+ if of_type:
+ right_mapper = of_type
+ else:
+ right_mapper = prop.mapper
+
+ if not right_entity:
+ right_entity = right_mapper
+
+ elif isinstance(onclause, basestring):
+ if not left_entity:
+ left_entity = self._joinpoint_zero()
+
+ descriptor, prop = _entity_descriptor(left_entity, onclause)
+ right_mapper = prop.mapper
+ if not right_entity:
+ right_entity = right_mapper
+ elif onclause is None:
+ if not left_entity:
+ left_entity = self._joinpoint_zero()
+ else:
+ if not left_entity:
+ left_entity = self._joinpoint_zero()
+
+ if not clause:
+ if isinstance(onclause, interfaces.PropComparator):
+ clause = onclause.__clause_element__()
- def _get_joinable_tables(self):
- if not self.__joinable_tables or self.__joinable_tables[0] is not self._from_obj:
- currenttables = [self._from_obj]
- def visit_join(join):
- currenttables.append(join.left)
- currenttables.append(join.right)
- visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
- self.__joinable_tables = (self._from_obj, currenttables)
- return currenttables
- else:
- return self.__joinable_tables[1]
+ for ent in self._mapper_entities:
+ if ent.corresponds_to(left_entity):
+ clause = ent.selectable
+ break
- def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
- if start is None:
- start = self._joinpoint
+ if not clause:
+                raise sa_exc.InvalidRequestError("Could not find a FROM clause to join from")
- clause = self._from_obj
+ bogus, right_selectable, is_aliased_class = _entity_info(right_entity)
- currenttables = self._get_joinable_tables()
+ if right_mapper and not is_aliased_class:
+ if right_entity is right_selectable:
- # determine if generated joins need to be aliased on the left
- # hand side.
- if self._aliases_head is self._aliases_tail is not None:
- adapt_against = self._aliases_tail.alias
- elif start is not self.mapper and self._aliases_tail:
- adapt_against = self._aliases_tail.alias
- else:
- adapt_against = None
+ if not right_selectable.is_derived_from(right_mapper.mapped_table):
+ raise sa_exc.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (right_selectable.description, right_mapper.mapped_table.description))
- mapper = start
- alias = self._aliases_tail
+ if not isinstance(right_selectable, expression.Alias):
+ right_selectable = right_selectable.alias()
- if not isinstance(keys, list):
- keys = [keys]
-
- for key in keys:
- use_selectable = None
- of_type = None
- is_aliased_class = False
-
- if isinstance(key, tuple):
- key, use_selectable = key
-
- if isinstance(key, interfaces.PropComparator):
- prop = key.property
- if getattr(key, '_of_type', None):
- of_type = key._of_type
- if not use_selectable:
- use_selectable = key._of_type.mapped_table
- else:
- prop = mapper.get_property(key, resolve_synonyms=True)
-
- if use_selectable:
- if _is_aliased_class(use_selectable):
- use_selectable = use_selectable.alias
- is_aliased_class = True
- if not use_selectable.is_derived_from(prop.mapper.mapped_table):
- raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description))
- if not isinstance(use_selectable, expression.Alias):
- use_selectable = use_selectable.alias()
- elif prop.mapper.with_polymorphic:
- use_selectable = prop.mapper._with_polymorphic_selectable()
- if not isinstance(use_selectable, expression.Alias):
- use_selectable = use_selectable.alias()
-
- if prop._is_self_referential() and not create_aliases and not use_selectable:
- raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires aliased=True argument." % str(prop))
-
- if prop.table not in currenttables or create_aliases or use_selectable:
+ right_entity = aliased(right_mapper, right_selectable)
+ alias_criterion = True
+
+ elif right_mapper.with_polymorphic or isinstance(right_mapper.mapped_table, expression.Join):
+ aliased_entity = True
+ right_entity = aliased(right_mapper)
+ alias_criterion = True
- if use_selectable or create_aliases:
- alias = mapperutil.PropertyAliasedClauses(prop,
- prop.primaryjoin,
- prop.secondaryjoin,
- alias,
- alias=use_selectable,
- should_adapt=not is_aliased_class
- )
- crit = alias.primaryjoin
+ elif create_aliases:
+ right_entity = aliased(right_mapper)
+ alias_criterion = True
+
+ elif prop:
+ if prop.table in self.__currenttables:
+ if prop.secondary is not None and prop.secondary not in self.__currenttables:
+ # TODO: this check is not strong enough for different paths to the same endpoint which
+ # does not use secondary tables
+                    raise sa_exc.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `aliased=True` argument to `join()`." % descriptor)
+
+ continue
+
if prop.secondary:
- clause = clause.join(alias.secondary, crit, isouter=outerjoin)
- clause = clause.join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
- else:
- clause = clause.join(alias.alias, crit, isouter=outerjoin)
- else:
- assert not prop.mapper.with_polymorphic
- pj, sj, source, dest, target_adapter = prop._create_joins(source_selectable=adapt_against)
- if sj:
- clause = clause.join(prop.secondary, pj, isouter=outerjoin)
- clause = clause.join(prop.table, sj, isouter=outerjoin)
- else:
- clause = clause.join(prop.table, pj, isouter=outerjoin)
-
- elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables:
- # TODO: this check is not strong enough for different paths to the same endpoint which
- # does not use secondary tables
- raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key)
+ self.__currenttables.add(prop.secondary)
+ self.__currenttables.add(prop.table)
- mapper = of_type or prop.mapper
+ right_entity = prop.mapper
- if use_selectable:
- adapt_against = use_selectable
-
- return (clause, mapper, alias)
+ if prop:
+ onclause = prop
+
+ clause = orm_join(clause, right_entity, onclause, isouter=outerjoin)
+ if alias_criterion:
+ self._filter_aliases = ORMAdapter(right_entity,
+ equivalents=right_mapper._equivalent_columns, chain_to=self._filter_aliases)
+
+ if aliased_entity:
+ self.__mapper_loads_polymorphically_with(right_mapper, ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns))
+
+ self._from_obj = clause
+ self._joinpoint = right_entity
+    __join = _generative(__no_statement_condition, __no_limit_offset)(__join)
def reset_joinpoint(self):
"""return a new Query reset the 'joinpoint' of this Query reset
the root.
"""
- q = self.__no_statement("reset_joinpoint")
- q._joinpoint = q.mapper
- if q.table not in q._get_joinable_tables():
- q._aliases_head = q._aliases_tail = mapperutil.AliasedClauses(q._from_obj, equivalents=q.mapper._equivalent_columns)
- else:
- q._aliases_head = q._aliases_tail = None
- return q
+ self.__reset_joinpoint()
+    reset_joinpoint = _generative(__no_statement_condition)(reset_joinpoint)
def select_from(self, from_obj):
"""Set the `from_obj` parameter of the query and return the newly
`from_obj` is a single table or selectable.
"""
- new = self.__no_criterion('select_from')
if isinstance(from_obj, (tuple, list)):
util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.")
from_obj = from_obj[-1]
+
+ self.__set_select_from(from_obj)
+    select_from = _generative(__no_from_condition, __no_criterion_condition)(select_from)
- new.__set_select_from(from_obj)
- return new
-
def __getitem__(self, item):
if isinstance(item, slice):
start = item.start
``Query``.
"""
- new = self.__no_statement("distinct")
- new._distinct = True
- return new
+ self._distinct = True
+    distinct = _generative(__no_statement_condition)(distinct)
def all(self):
"""Return the results represented by this ``Query`` as a list.
"""
return list(self)
-
def from_statement(self, statement):
"""Execute the given SELECT statement and return results.
"""
if isinstance(statement, basestring):
statement = sql.text(statement)
- q = self.__no_criterion('from_statement')
- q._statement = statement
- return q
+ self._statement = statement
+    from_statement = _generative(__no_criterion_condition)(from_statement)
def first(self):
"""Return the first result of this ``Query`` or None if the result doesn't contain any row.
This results in an execution of the underlying query.
"""
- if self._column_aggregate is not None:
- return self._col_aggregate(*self._column_aggregate)
-
ret = list(self[0:1])
if len(ret) > 0:
return ret[0]
This results in an execution of the underlying query.
"""
- if self._column_aggregate is not None:
- return self._col_aggregate(*self._column_aggregate)
-
ret = list(self[0:2])
if len(ret) == 1:
return ret[0]
elif len(ret) == 0:
- raise exceptions.InvalidRequestError('No rows returned for one()')
+ raise sa_exc.InvalidRequestError('No rows returned for one()')
else:
- raise exceptions.InvalidRequestError('Multiple rows returned for one()')
+ raise sa_exc.InvalidRequestError('Multiple rows returned for one()')
def __iter__(self):
context = self._compile_context()
return self._execute_and_instances(context)
def _execute_and_instances(self, querycontext):
- result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper, instance=self._refresh_instance)
- return self.iterate_instances(result, querycontext=querycontext)
+ result = self.session.execute(querycontext.statement, params=self._params, mapper=self._mapper_zero_or_none(), instance=self._refresh_instance)
+ return self.iterate_instances(result, querycontext)
- def instances(self, cursor, *mappers_or_columns, **kwargs):
- return list(self.iterate_instances(cursor, *mappers_or_columns, **kwargs))
+ def instances(self, cursor, __context=None):
+ return list(self.iterate_instances(cursor, __context))
- def iterate_instances(self, cursor, *mappers_or_columns, **kwargs):
+ def iterate_instances(self, cursor, __context=None):
session = self.session
- context = kwargs.pop('querycontext', None)
+ context = __context
if context is None:
context = QueryContext(self)
context.runid = _new_runid()
- entities = self._entities + [_QueryEntity.legacy_guess_type(mc) for mc in mappers_or_columns]
-
- if getattr(self, '_no_filters', False):
- filter = None
- single_entity = custom_rows = False
- else:
- single_entity = isinstance(entities[0], _PrimaryMapperEntity) and len(entities) == 1
- custom_rows = single_entity and 'append_result' in context.extension.methods
-
+ filtered = bool(list(self._mapper_entities))
+ single_entity = filtered and len(self._entities) == 1
+
+ if filtered:
if single_entity:
filter = util.OrderedIdentitySet
else:
filter = util.OrderedSet
-
- process = [query_entity.row_processor(self, context, single_entity) for query_entity in entities]
+ else:
+ filter = None
+
+ custom_rows = single_entity and 'append_result' in self._entities[0].extension.methods
+ (process, labels) = zip(*[query_entity.row_processor(self, context, custom_rows) for query_entity in self._entities])
+
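+        # for multiple entities, build a lightweight tuple subclass whose
+        # elements are also readable by entity/column label via attribute
+        # access (each label maps to an itemgetter-based property)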
+ if not single_entity:
+ labels = dict([(label, property(util.itemgetter(i))) for i, label in enumerate(labels) if label])
+ rowtuple = type.__new__(type, "RowTuple", (tuple,), labels)
+ rowtuple.keys = labels.keys
+
while True:
context.progress = util.Set()
context.partials = {}
if self._yield_per:
fetch = cursor.fetchmany(self._yield_per)
if not fetch:
- return
+ break
else:
fetch = cursor.fetchall()
elif single_entity:
rows = [process[0](context, row) for row in fetch]
else:
- rows = [tuple([proc(context, row) for proc in process]) for row in fetch]
+ rows = [rowtuple([proc(context, row) for proc in process]) for row in fetch]
if filter:
rows = filter(rows)
- if context.refresh_instance and context.only_load_props and context.refresh_instance in context.progress:
- context.refresh_instance.commit(context.only_load_props)
+ if context.refresh_instance and self._only_load_props and context.refresh_instance in context.progress:
+ context.refresh_instance.commit(self._only_load_props)
context.progress.remove(context.refresh_instance)
- for ii in context.progress:
- context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii)
- ii.commit_all()
-
+ session._finalize_loaded(context.progress)
+
for ii, attrs in context.partials.items():
- context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii, only_load_props=attrs)
ii.commit(attrs)
-
+
for row in rows:
yield row
def _get(self, key=None, ident=None, refresh_instance=None, lockmode=None, only_load_props=None):
lockmode = lockmode or self._lockmode
- if not self._populate_existing and not refresh_instance and not self.mapper.always_refresh and lockmode is None:
+ if not self._populate_existing and not refresh_instance and not self._mapper_zero().always_refresh and lockmode is None:
try:
- return self.session.identity_map[key]
+ instance = self.session.identity_map[key]
+ state = attributes.instance_state(instance)
+ if state.expired:
+ try:
+ state()
+ except orm_exc.ObjectDeletedError:
+ # TODO: should we expunge ? if so, should we expunge here ? or in mapper._load_scalar_attributes ?
+ self.session.expunge(instance)
+ return None
+ return instance
except KeyError:
pass
else:
ident = util.to_list(ident)
- q = self
-
- # dont use 'polymorphic' mapper if we are refreshing an instance
- if refresh_instance and q.mapper is not q.mapper:
- q = q.__reset_all(q.mapper, '_get')
+ if refresh_instance is None:
+ q = self.__no_criterion()
+ else:
+ q = self._clone()
if ident is not None:
- q = q.__no_criterion('get')
+ mapper = q._mapper_zero()
params = {}
- (_get_clause, _get_params) = q.mapper._get_clause
- q = q.filter(_get_clause)
- for i, primary_key in enumerate(q.mapper.primary_key):
+ (_get_clause, _get_params) = mapper._get_clause
+
+ _get_clause = q._adapt_clause(_get_clause, True, False)
+ q._criterion = _get_clause
+
+ for i, primary_key in enumerate(mapper.primary_key):
try:
params[_get_params[primary_key].key] = ident[i]
except IndexError:
- raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in q.mapper.primary_key]))
- q = q.params(params)
+ raise sa_exc.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in q.mapper.primary_key]))
+ q._params = params
if lockmode is not None:
- q = q.with_lockmode(lockmode)
- q = q.__get_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
+ q._lockmode = lockmode
+ q.__get_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
q._order_by = None
try:
# call using all() to avoid LIMIT compilation complexity
def _select_args(self):
return {'limit':self._limit, 'offset':self._offset, 'distinct':self._distinct, 'group_by':self._group_by or None, 'having':self._having or None}
_select_args = property(_select_args)
-
+
def _should_nest_selectable(self):
kwargs = self._select_args
return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False))
_should_nest_selectable = property(_should_nest_selectable)
- def count(self, whereclause=None, params=None, **kwargs):
- """Apply this query's criterion to a SELECT COUNT statement.
-
- the whereclause, params and \**kwargs arguments are deprecated. use filter()
- and other generative methods to establish modifiers.
-
- """
- q = self
- if whereclause is not None:
- q = q.filter(whereclause)
- if params is not None:
- q = q.params(params)
- q = q._legacy_select_kwargs(**kwargs)
- return q._count()
-
- def _count(self):
+ def count(self):
"""Apply this query's criterion to a SELECT COUNT statement.
        this is the purely generative version; use filter() and other
        generative methods to establish modifiers.
"""
- return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self.mapper.primary_key))
+ return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key))
def _col_aggregate(self, col, func, nested_cols=None):
whereclause = self._criterion
-
+
context = QueryContext(self)
- from_obj = self._from_obj
+ from_obj = self.__mapper_zero_from_obj()
if self._should_nest_selectable:
if not nested_cols:
s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s)
else:
s = sql.select([func(col)], whereclause, from_obj=from_obj, **self._select_args)
-
+
if self._autoflush and not self._populate_existing:
self.session._autoflush()
- return self.session.scalar(s, params=self._params, mapper=self.mapper)
+ return self.session.scalar(s, params=self._params, mapper=self._mapper_zero())
def compile(self):
"""compiles and returns a SQL statement based on the criterion and conditions within this Query."""
return self._compile_context().statement
- def _compile_context(self):
-
+ def _compile_context(self, labels=True):
context = QueryContext(self)
- if self._statement:
- self._statement.use_labels = True
- context.statement = self._statement
+ if context.statement:
return context
- from_obj = self._from_obj
- adapter = self._aliases_head
-
if self._lockmode:
try:
- for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
+ for_update = {'read': 'read',
+ 'update': True,
+ 'update_nowait': 'nowait',
+ None: False}[self._lockmode]
except KeyError:
- raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
+ raise sa_exc.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
else:
for_update = False
-
- context.from_clause = from_obj
- context.whereclause = self._criterion
- context.order_by = self._order_by
-
+
for entity in self._entities:
entity.setup_context(self, context)
-
- if self._eager_loaders and self._should_nest_selectable:
- # eager loaders are present, and the SELECT has limiting criterion
- # produce a "wrapped" selectable.
-
+
+ eager_joins = context.eager_joins.values()
+
+ if context.from_clause:
+ froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used
+ else:
+ froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM
+
+ if eager_joins and self._should_nest_selectable:
+ # for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select,
+ # then append eager joins onto that
+
if context.order_by:
- context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []]
- if adapter:
- context.order_by = adapter.adapt_list(context.order_by)
- # locate all embedded Column clauses so they can be added to the
- # "inner" select statement where they'll be available to the enclosing
- # statement's "order by"
- # TODO: this likely doesn't work with very involved ORDER BY expressions,
- # such as those including subqueries
order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by]))
else:
context.order_by = None
order_by_col_expr = []
-
- if adapter:
- context.primary_columns = adapter.adapt_list(context.primary_columns)
-
- inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=context.order_by, **self._select_args).alias()
- local_adapter = sql_util.ClauseAdapter(inner)
- context.row_adapter = mapperutil.create_row_adapter(inner, equivalent_columns=self.mapper._equivalent_columns)
+ inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args)
+
+ if self._correlate:
+ inner = inner.correlate(*self._correlate)
+
+ inner = inner.alias()
- statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=True)
+ equivs = self.__all_equivs()
- if context.eager_joins:
- eager_joins = local_adapter.traverse(context.eager_joins)
- statement.append_from(eager_joins)
+ context.adapter = sql_util.ColumnAdapter(inner, equivs)
+
+ statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=labels)
+
+ from_clause = inner
+ for eager_join in eager_joins:
+ # EagerLoader places a 'stop_on' attribute on the join,
+ # giving us a marker as to where the "splice point" of the join should be
+ from_clause = sql_util.splice_joins(from_clause, eager_join, eager_join.stop_on)
+
+ statement.append_from(from_clause)
if context.order_by:
+ local_adapter = sql_util.ClauseAdapter(inner)
statement.append_order_by(*local_adapter.copy_and_process(context.order_by))
statement.append_order_by(*context.eager_order_by)
else:
- if context.order_by:
- context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []]
- if adapter:
- context.order_by = adapter.adapt_list(context.order_by)
- else:
+ if not context.order_by:
context.order_by = None
-
- if adapter:
- context.primary_columns = adapter.adapt_list(context.primary_columns)
- context.row_adapter = mapperutil.create_row_adapter(adapter.alias, equivalent_columns=self.mapper._equivalent_columns)
-
+
if self._distinct and context.order_by:
order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by]))
context.primary_columns += order_by_col_expr
- statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=context.order_by, **self._select_args)
+ froms += context.eager_joins.values()
- if context.eager_joins:
- if adapter:
- context.eager_joins = adapter.adapt_clause(context.eager_joins)
- statement.append_from(context.eager_joins)
+ statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args)
+ if self._correlate:
+ statement = statement.correlate(*self._correlate)
if context.eager_order_by:
- if adapter:
- context.eager_order_by = adapter.adapt_list(context.eager_order_by)
statement.append_order_by(*context.eager_order_by)
- # polymorphic mappers which have concrete tables in their hierarchy usually
- # require row aliasing unconditionally.
- if not context.row_adapter and self.mapper._requires_row_aliasing:
- context.row_adapter = mapperutil.create_row_adapter(self.table, equivalent_columns=self.mapper._equivalent_columns)
-
- context.statement = statement
+ context.statement = statement._annotate({'_halt_adapt': True})
return context
def __str__(self):
return str(self.compile())
- # DEPRECATED LAND !
-
- def _generative_col_aggregate(self, col, func):
- """apply the given aggregate function to the query and return the newly
- resulting ``Query``. (deprecated)
- """
- if self._column_aggregate is not None:
- raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
- q = self.__no_statement("aggregate")
- q._column_aggregate = (col, func)
- return q
-
- def apply_min(self, col):
- """apply the SQL ``min()`` function against the given column to the
- query and return the newly resulting ``Query``.
-
- DEPRECATED.
- """
- return self._generative_col_aggregate(col, sql.func.min)
-
- def apply_max(self, col):
- """apply the SQL ``max()`` function against the given column to the
- query and return the newly resulting ``Query``.
-
- DEPRECATED.
- """
- return self._generative_col_aggregate(col, sql.func.max)
-
- def apply_sum(self, col):
- """apply the SQL ``sum()`` function against the given column to the
- query and return the newly resulting ``Query``.
-
- DEPRECATED.
- """
- return self._generative_col_aggregate(col, sql.func.sum)
-
- def apply_avg(self, col):
- """apply the SQL ``avg()`` function against the given column to the
- query and return the newly resulting ``Query``.
-
- DEPRECATED.
- """
- return self._generative_col_aggregate(col, sql.func.avg)
-
- def list(self): #pragma: no cover
- """DEPRECATED. use all()"""
-
- return list(self)
-
- def scalar(self): #pragma: no cover
- """DEPRECATED. use first()"""
-
- return self.first()
-
- def _legacy_filter_by(self, *args, **kwargs): #pragma: no cover
- return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint))
-
- def count_by(self, *args, **params): #pragma: no cover
- """DEPRECATED. use query.filter_by(\**params).count()"""
-
- return self.count(self.join_by(*args, **params))
-
- def select_whereclause(self, whereclause=None, params=None, **kwargs): #pragma: no cover
- """DEPRECATED. use query.filter(whereclause).all()"""
-
- q = self.filter(whereclause)._legacy_select_kwargs(**kwargs)
- if params is not None:
- q = q.params(params)
- return list(q)
+class _QueryEntity(object):
+ """represent an entity column returned within a Query result."""
- def _legacy_select_from(self, from_obj):
- q = self._clone()
- if len(from_obj) > 1:
- raise exceptions.ArgumentError("Multiple-entry from_obj parameter no longer supported")
- q._from_obj = from_obj[0]
- return q
+ def __new__(cls, *args, **kwargs):
+ if cls is _QueryEntity:
+ entity = args[1]
+ if _is_mapped_class(entity):
+ cls = _MapperEntity
+ else:
+ cls = _ColumnEntity
+ return object.__new__(cls)
- def _legacy_select_kwargs(self, **kwargs): #pragma: no cover
- q = self
- if "order_by" in kwargs and kwargs['order_by']:
- q = q.order_by(kwargs['order_by'])
- if "group_by" in kwargs:
- q = q.group_by(kwargs['group_by'])
- if "from_obj" in kwargs:
- q = q._legacy_select_from(kwargs['from_obj'])
- if "lockmode" in kwargs:
- q = q.with_lockmode(kwargs['lockmode'])
- if "distinct" in kwargs:
- q = q.distinct()
- if "limit" in kwargs:
- q = q.limit(kwargs['limit'])
- if "offset" in kwargs:
- q = q.offset(kwargs['offset'])
+ def _clone(self):
+ q = self.__class__.__new__(self.__class__)
+ q.__dict__ = self.__dict__.copy()
return q
+class _MapperEntity(_QueryEntity):
+ """mapper/class/AliasedClass entity"""
- def get_by(self, *args, **params): #pragma: no cover
- """DEPRECATED. use query.filter_by(\**params).first()"""
-
- ret = self._extension.get_by(self, *args, **params)
- if ret is not mapper.EXT_CONTINUE:
- return ret
-
- return self._legacy_filter_by(*args, **params).first()
-
- def select_by(self, *args, **params): #pragma: no cover
- """DEPRECATED. use use query.filter_by(\**params).all()."""
-
- ret = self._extension.select_by(self, *args, **params)
- if ret is not mapper.EXT_CONTINUE:
- return ret
-
- return self._legacy_filter_by(*args, **params).list()
-
- def join_by(self, *args, **params): #pragma: no cover
- """DEPRECATED. use join() to construct joins based on attribute names."""
+ def __init__(self, query, entity, entity_name=None):
+ self.primary_entity = not query._entities
+ query._entities.append(self)
- return self._legacy_join_by(args, params, start=self._joinpoint)
+ self.entities = [entity]
+ self.entity_zero = entity
+ self.entity_name = entity_name
- def _build_select(self, arg=None, params=None, **kwargs): #pragma: no cover
- if isinstance(arg, sql.FromClause) and arg.supports_execution():
- return self.from_statement(arg)
- elif arg is not None:
- return self.filter(arg)._legacy_select_kwargs(**kwargs)
+ def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic):
+ self.mapper = mapper
+ self.extension = self.mapper.extension
+ self.adapter = adapter
+ self.selectable = from_obj
+ self._with_polymorphic = with_polymorphic
+ self.is_aliased_class = is_aliased_class
+ if is_aliased_class:
+ self.path_entity = self.entity = self.entity_zero = entity
else:
- return self._legacy_select_kwargs(**kwargs)
-
- def selectfirst(self, arg=None, **kwargs): #pragma: no cover
- """DEPRECATED. use query.filter(whereclause).first()"""
-
- return self._build_select(arg, **kwargs).first()
-
- def selectone(self, arg=None, **kwargs): #pragma: no cover
- """DEPRECATED. use query.filter(whereclause).one()"""
-
- return self._build_select(arg, **kwargs).one()
-
- def select(self, arg=None, **kwargs): #pragma: no cover
- """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()"""
+ self.path_entity = mapper.base_mapper
+ self.entity = self.entity_zero = mapper
- ret = self._extension.select(self, arg=arg, **kwargs)
- if ret is not mapper.EXT_CONTINUE:
- return ret
- return self._build_select(arg, **kwargs).all()
-
- def execute(self, clauseelement, params=None, *args, **kwargs): #pragma: no cover
- """DEPRECATED. use query.from_statement().all()"""
-
- return self._select_statement(clauseelement, params, **kwargs)
-
- def select_statement(self, statement, **params): #pragma: no cover
- """DEPRECATED. Use query.from_statement(statement)"""
-
- return self._select_statement(statement, params)
-
- def select_text(self, text, **params): #pragma: no cover
- """DEPRECATED. Use query.from_statement(statement)"""
+ def set_with_polymorphic(self, query, cls_or_mappers, selectable):
+ if cls_or_mappers is None:
+ query._reset_polymorphic_adapter(self.mapper)
+ return
- return self._select_statement(text, params)
-
- def _select_statement(self, statement, params=None, **kwargs): #pragma: no cover
- q = self.from_statement(statement)
- if params is not None:
- q = q.params(params)
- q.__get_options(**kwargs)
- return list(q)
-
- def join_to(self, key): #pragma: no cover
- """DEPRECATED. use join() to create joins based on property names."""
-
- [keys, p] = self._locate_prop(key)
- return self.join_via(keys)
+ mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
+ self._with_polymorphic = mappers
- def join_via(self, keys): #pragma: no cover
- """DEPRECATED. use join() to create joins based on property names."""
+ # TODO: do the wrapped thing here too so that with_polymorphic() can be
+ # applied to aliases
+ if not self.is_aliased_class:
+ self.selectable = from_obj
+ self.adapter = query._get_polymorphic_adapter(self, from_obj)
- mapper = self._joinpoint
- clause = None
- for key in keys:
- prop = mapper.get_property(key, resolve_synonyms=True)
- if clause is None:
- clause = prop._get_join(mapper)
- else:
- clause &= prop._get_join(mapper)
- mapper = prop.mapper
+ def corresponds_to(self, entity):
+ if _is_aliased_class(entity):
+ return entity is self.path_entity
+ else:
+ return entity.base_mapper is self.path_entity
- return clause
+ def _get_entity_clauses(self, query, context):
- def _legacy_join_by(self, args, params, start=None): #pragma: no cover
- import properties
+ adapter = None
+ if not self.is_aliased_class and query._polymorphic_adapters:
+ for mapper in self.mapper.iterate_to_root():
+ adapter = query._polymorphic_adapters.get(mapper.mapped_table, None)
+ if adapter:
+ break
- clause = None
- for arg in args:
- if clause is None:
- clause = arg
- else:
- clause &= arg
+ if not adapter and self.adapter:
+ adapter = self.adapter
- for key, value in params.iteritems():
- (keys, prop) = self._locate_prop(key, start=start)
- if isinstance(prop, properties.PropertyLoader):
- c = prop.compare(operators.eq, value) & self.join_via(keys[:-1])
+ if adapter:
+ if query._from_obj_alias:
+ ret = adapter.wrap(query._from_obj_alias)
else:
- c = prop.compare(operators.eq, value) & self.join_via(keys)
- if clause is None:
- clause = c
- else:
- clause &= c
- return clause
-
- def _locate_prop(self, key, start=None): #pragma: no cover
- import properties
- keys = []
- seen = util.Set()
- def search_for_prop(mapper_):
- if mapper_ in seen:
- return None
- seen.add(mapper_)
-
- prop = mapper_.get_property(key, resolve_synonyms=True, raiseerr=False)
- if prop is not None:
- if isinstance(prop, properties.PropertyLoader):
- keys.insert(0, prop.key)
- return prop
- else:
- for prop in mapper_.iterate_properties:
- if not isinstance(prop, properties.PropertyLoader):
- continue
- x = search_for_prop(prop.mapper)
- if x:
- keys.insert(0, prop.key)
- return x
- else:
- return None
- p = search_for_prop(start or self.mapper)
- if p is None:
- raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key)
- return [keys, p]
-
- def selectfirst_by(self, *args, **params): #pragma: no cover
- """DEPRECATED. Use query.filter_by(\**kwargs).first()"""
-
- return self._legacy_filter_by(*args, **params).first()
-
- def selectone_by(self, *args, **params): #pragma: no cover
- """DEPRECATED. Use query.filter_by(\**kwargs).one()"""
-
- return self._legacy_filter_by(*args, **params).one()
-
- for deprecated_method in ('list', 'scalar', 'count_by',
- 'select_whereclause', 'get_by', 'select_by',
- 'join_by', 'selectfirst', 'selectone', 'select',
- 'execute', 'select_statement', 'select_text',
- 'join_to', 'join_via', 'selectfirst_by',
- 'selectone_by', 'apply_max', 'apply_min',
- 'apply_avg', 'apply_sum'):
- locals()[deprecated_method] = \
- util.deprecated(None, False)(locals()[deprecated_method])
-
-class _QueryEntity(object):
- """represent an entity column returned within a Query result."""
-
- def legacy_guess_type(self, e):
- if isinstance(e, type):
- return _MapperEntity(mapper=mapper.class_mapper(e))
- elif isinstance(e, mapper.Mapper):
- return _MapperEntity(mapper=e)
+ ret = adapter
else:
- return _ColumnEntity(column=e)
- legacy_guess_type=classmethod(legacy_guess_type)
+ ret = query._from_obj_alias
-class _MapperEntity(_QueryEntity):
- """entity column corresponding to mapped ORM instances."""
-
- def __init__(self, mapper, alias=None, id=None):
- self.mapper = mapper
- self.alias = alias
- self.alias_id = id
-
- def _get_entity_clauses(self, query):
- if self.alias:
- return self.alias
- elif self.alias_id:
- try:
- return query._alias_ids[self.alias_id][0]
- except KeyError:
- raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id)
-
- l = query._alias_ids.get(self.mapper)
- if l:
- if len(l) > 1:
- raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_entity()" % str(self.mapper))
- return l[0]
- else:
- return None
-
- def row_processor(self, query, context, single_entity):
- clauses = self._get_entity_clauses(query)
- if clauses:
- def proc(context, row):
- return self.mapper._instance(context, clauses.row_decorator(row), None)
- else:
- def proc(context, row):
- return self.mapper._instance(context, row, None)
-
- return proc
-
- def setup_context(self, query, context):
- clauses = self._get_entity_clauses(query)
- for value in self.mapper.iterate_properties:
- context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses)
+ return ret
- def __str__(self):
- return str(self.mapper)
+ def row_processor(self, query, context, custom_rows):
+ adapter = self._get_entity_clauses(query, context)
-class _PrimaryMapperEntity(_MapperEntity):
- """entity column corresponding to the 'primary' (first) mapped ORM instance."""
+ if context.adapter and adapter:
+ adapter = adapter.wrap(context.adapter)
+ elif not adapter:
+ adapter = context.adapter
- def row_processor(self, query, context, single_entity):
- if single_entity and 'append_result' in context.extension.methods:
+ # polymorphic mappers which have concrete tables in their hierarchy usually
+ # require row aliasing unconditionally.
+ if not adapter and self.mapper._requires_row_aliasing:
+ adapter = sql_util.ColumnAdapter(self.selectable, self.mapper._equivalent_columns)
+
+ if self.primary_entity:
+ _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter,
+ extension=self.extension, only_load_props=query._only_load_props, refresh_instance=context.refresh_instance
+ )
+ else:
+ _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter)
+
+ if custom_rows:
def main(context, row, result):
- if context.row_adapter:
- row = context.row_adapter(row)
- self.mapper._instance(context, row, result,
- extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
- )
- elif context.row_adapter:
- def main(context, row):
- return self.mapper._instance(context, context.row_adapter(row), None,
- extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
- )
+ _instance(row, result)
else:
def main(context, row):
- return self.mapper._instance(context, row, None,
- extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
- )
+ return _instance(row, None)
- return main
+ if self.is_aliased_class:
+ entname = self.entity._sa_label_name
+ else:
+ entname = self.mapper.class_.__name__
+
+ return main, entname
def setup_context(self, query, context):
# if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
# that we only load the appropriate types
if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
-
- if context.order_by is False:
- if self.mapper.order_by:
- context.order_by = self.mapper.order_by
- elif context.from_clause.default_order_by():
- context.order_by = context.from_clause.default_order_by()
-
- for value in self.mapper._iterate_polymorphic_properties(query._with_polymorphic, context.from_clause):
+
+ context.froms.append(self.selectable)
+
+ adapter = self._get_entity_clauses(query, context)
+
+ if self.primary_entity:
+ if context.order_by is False:
+ # the "default" ORDER BY use case applies only to "mapper zero". the "from clause" default should
+ # go away in 0.5 (or...maybe 0.6).
+ if self.mapper.order_by:
+ context.order_by = self.mapper.order_by
+ elif context.from_clause:
+ context.order_by = context.from_clause.default_order_by()
+ else:
+ context.order_by = self.selectable.default_order_by()
+ if context.order_by and adapter:
+ context.order_by = adapter.adapt_list(util.to_list(context.order_by))
+
+ for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic):
if query._only_load_props and value.key not in query._only_load_props:
continue
- context.exec_with_path(self.mapper, value.key, value.setup, context, only_load_props=query._only_load_props)
+ value.setup(context, self, (self.path_entity,), adapter, only_load_props=query._only_load_props, column_collection=context.primary_columns)
+
+ def __str__(self):
+ return str(self.mapper)
+
class _ColumnEntity(_QueryEntity):
- """entity column corresponding to Table or selectable columns."""
+ """Column/expression based entity."""
+
+ def __init__(self, query, column, entity_name=None):
+ if isinstance(column, expression.FromClause) and not isinstance(column, expression.ColumnElement):
+ for c in column.c:
+ _ColumnEntity(query, c)
+ return
+
+ query._entities.append(self)
- def __init__(self, column, id):
if isinstance(column, basestring):
column = sql.literal_column(column)
-
- if column and isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
+ elif isinstance(column, (attributes.QueryableAttribute, mapper.Mapper._CompileOnAttr)):
+ column = column.__clause_element__()
+ elif not isinstance(column, sql.ColumnElement):
+ raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column)
+
+ if not hasattr(column, '_label'):
column = column.label(None)
+
self.column = column
- self.alias_id = id
+ self.entity_name = None
+ self.froms = util.Set()
+ self.entities = util.OrderedSet([elem._annotations['parententity'] for elem in visitors.iterate(column, {}) if 'parententity' in elem._annotations])
+ if self.entities:
+ self.entity_zero = list(self.entities)[0]
+ else:
+ self.entity_zero = None
+
+ def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic):
+ self.froms.add(from_obj)
def __resolve_expr_against_query_aliases(self, query, expr, context):
- if not query._alias_ids:
- return expr
-
- if ('_ColumnEntity', expr) in context.attributes:
- return context.attributes[('_ColumnEntity', expr)]
-
- if self.alias_id:
- try:
- aliases = query._alias_ids[self.alias_id][0]
- except KeyError:
- raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id)
+ return query._adapt_clause(expr, False, True)
- def _locate_aliased(element):
- if element in query._alias_ids:
- return aliases
- else:
- def _locate_aliased(element):
- if element in query._alias_ids:
- aliases = query._alias_ids[element]
- if len(aliases) > 1:
- raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_column(), or use the aliased() function to use explicit class aliases." % expr)
- return aliases[0]
- return None
-
- class Adapter(visitors.ClauseVisitor):
- def before_clone(self, element):
- if isinstance(element, expression.FromClause):
- alias = _locate_aliased(element)
- if alias:
- return alias.alias
-
- if hasattr(element, 'table'):
- alias = _locate_aliased(element.table)
- if alias:
- return alias.aliased_column(element)
+ def row_processor(self, query, context, custom_rows):
+ column = self.__resolve_expr_against_query_aliases(query, self.column, context)
- return None
+ if context.adapter:
+ column = context.adapter.columns[column]
- context.attributes[('_ColumnEntity', expr)] = ret = Adapter().traverse(expr, clone=True)
- return ret
-
- def row_processor(self, query, context, single_entity):
- column = self.__resolve_expr_against_query_aliases(query, self.column, context)
def proc(context, row):
return row[column]
- return proc
-
+
+ return (proc, getattr(column, 'name', None))
+
def setup_context(self, query, context):
column = self.__resolve_expr_against_query_aliases(query, self.column, context)
- context.secondary_columns.append(column)
-
+ context.froms += list(self.froms)
+ context.primary_columns.append(column)
+
def __str__(self):
return str(self.column)
-
-Query.logger = logging.class_logger(Query)
+Query.logger = log.class_logger(Query)
class QueryContext(object):
def __init__(self, query):
+
+ if query._statement:
+ if isinstance(query._statement, expression._SelectBaseMixin) and not query._statement.use_labels:
+ self.statement = query._statement.apply_labels()
+ else:
+ self.statement = query._statement
+ else:
+ self.statement = None
+ self.from_clause = query._from_obj
+ self.whereclause = query._criterion
+ self.order_by = query._order_by
+ if self.order_by:
+ self.order_by = [expression._literal_as_text(o) for o in util.to_list(self.order_by)]
+
self.query = query
- self.mapper = query.mapper
self.session = query.session
- self.extension = query._extension
- self.statement = None
- self.row_adapter = None
self.populate_existing = query._populate_existing
self.version_check = query._version_check
- self.only_load_props = query._only_load_props
self.refresh_instance = query._refresh_instance
- self.path = ()
self.primary_columns = []
self.secondary_columns = []
self.eager_order_by = []
- self.eager_joins = None
+
+ self.eager_joins = {}
+ self.froms = []
+ self.adapter = None
+
self.options = query._with_options
self.attributes = query._attributes.copy()
- def exec_with_path(self, mapper, propkey, fn, *args, **kwargs):
- oldpath = self.path
- self.path += (mapper.base_mapper, propkey)
- try:
- return fn(*args, **kwargs)
- finally:
- self.path = oldpath
+class AliasOption(interfaces.MapperOption):
+ def __init__(self, alias):
+ self.alias = alias
+ def process_query(self, query):
+ if isinstance(self.alias, basestring):
+ alias = query._mapper_zero().mapped_table.alias(self.alias)
+ else:
+ alias = self.alias
+ query._from_obj_alias = sql_util.ColumnAdapter(alias)
+
_runid = 1L
_id_lock = util.threading.Lock()
+# scoping.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import inspect
+import types
+
+import sqlalchemy.exceptions as sa_exc
from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs
-from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, class_mapper
+from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, \
+ class_mapper
from sqlalchemy.orm.session import Session
-from sqlalchemy import exceptions
-import types
__all__ = ['ScopedSession']
scope = kwargs.pop('scope', False)
if scope is not None:
if self.registry.has():
- raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
+ raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
else:
sess = self.session_factory(**kwargs)
self.registry.set(sess)
from sqlalchemy.orm import mapper
- extension_args = dict([(arg,kwargs.pop(arg))
+ extension_args = dict([(arg, kwargs.pop(arg))
for arg in get_cls_kwargs(_ScopedExt)
if arg in kwargs])
setattr(ScopedSession, prop, makeprop(prop))
def clslevel(name):
- def do(cls, *args,**kwargs):
+ def do(cls, *args, **kwargs):
return getattr(Session, name)(*args, **kwargs)
return classmethod(do)
-for prop in ('close_all','object_session', 'identity_key'):
+for prop in ('close_all', 'object_session', 'identity_key'):
setattr(ScopedSession, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
self.context = context
self.validate = validate
self.save_on_init = save_on_init
+ self.set_kwargs_on_init = None
def validating(self):
return _ScopedExt(self.context, validate=True)
def configure(self, **kwargs):
return _ScopedExt(self.context, **kwargs)
- def get_session(self):
- return self.context.registry()
-
def instrument_class(self, mapper, class_):
class query(object):
def __getattr__(s, key):
return getattr(self.context.registry().query(class_), key)
def __call__(s):
return self.context.registry().query(class_)
-
+ def __get__(self, instance, cls):
+ return self
+
if not 'query' in class_.__dict__:
class_.query = query()
-
+
+ if self.set_kwargs_on_init is None:
+ self.set_kwargs_on_init = class_.__init__ is object.__init__
+ if self.set_kwargs_on_init:
+ def __init__(self, **kwargs):
+ pass
+ class_.__init__ = __init__
+
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
if self.save_on_init:
entity_name = kwargs.pop('_sa_entity_name', None)
session = kwargs.pop('_sa_session', None)
- if not isinstance(oldinit, types.MethodType):
+
+ if self.set_kwargs_on_init:
for key, value in kwargs.items():
if self.validate:
- if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
- raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+ if not mapper.get_property(key, resolve_synonyms=False,
+ raiseerr=False):
+ raise sa_exc.ArgumentError(
+ "Invalid __init__ argument: '%s'" % key)
setattr(instance, key, value)
kwargs.clear()
+
if self.save_on_init:
session = session or self.context.registry()
- session._save_impl(instance, entity_name=entity_name)
+ session._save_without_cascade(instance, entity_name=entity_name)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
- object_session(instance).expunge(instance)
+ sess = object_session(instance)
+ if sess:
+ sess.expunge(instance)
return EXT_CONTINUE
def dispose_class(self, mapper, class_):
"""Provides the Session class and related utilities."""
-
import weakref
-from sqlalchemy import util, exceptions, sql, engine
-from sqlalchemy.orm import unitofwork, query, attributes, util as mapperutil
-from sqlalchemy.orm.mapper import object_mapper as _object_mapper
-from sqlalchemy.orm.mapper import class_mapper as _class_mapper
-from sqlalchemy.orm.mapper import Mapper
+import sqlalchemy.exceptions as sa_exc
+import sqlalchemy.orm.attributes
+from sqlalchemy import util, sql, engine
+from sqlalchemy.sql import util as sql_util, expression
+from sqlalchemy.orm import exc, unitofwork, query, attributes, \
+ util as mapperutil, SessionExtension
+from sqlalchemy.orm.util import object_mapper as _object_mapper
+from sqlalchemy.orm.util import class_mapper as _class_mapper
+from sqlalchemy.orm.util import _state_mapper, _state_has_identity, _class_to_mapper
+from sqlalchemy.orm.mapper import Mapper
+from sqlalchemy.orm.unitofwork import UOWTransaction
+from sqlalchemy.orm import identity
__all__ = ['Session', 'SessionTransaction', 'SessionExtension']
-def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs):
+def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, autoexpire=True, **kwargs):
"""Generate a custom-configured [sqlalchemy.orm.session#Session] class.
The returned object is a subclass of ``Session``, which, when instantiated with no
sess = Session()
- The function features a single keyword argument of its own, `class_`, which
- may be used to specify an alternate class other than ``sqlalchemy.orm.session.Session``
- which should be used by the returned class. All other keyword arguments sent to
- `sessionmaker()` are passed through to the instantiated `Session()` object.
- """
+ Options:
+
+ autocommit
+ Defaults to ``False``. When ``True``, the ``Session`` does not keep a
+ persistent transaction running, and will acquire connections from the engine
+ on an as-needed basis, returning them immediately after their use. Flushes
+ will begin and commit (or possibly roll back) their own transaction if no
+ transaction is present. When using this mode, the `session.begin()` method
+ may be used to begin a transaction explicitly.
+
+ Leaving it on its default value of ``False`` means that the ``Session`` will
+ acquire a connection and begin a transaction the first time it is used, which
+ it will maintain persistently until ``rollback()``, ``commit()``, or
+ ``close()`` is called. When the transaction is released by any of these
+ methods, the ``Session`` is ready for the next usage, which will again acquire
+ and maintain a new connection/transaction.
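+
+ For illustration, a minimal sketch of the autocommit pattern (the
+ engine URL is a placeholder)::
+
+     from sqlalchemy import create_engine
+     from sqlalchemy.orm import sessionmaker
+
+     Session = sessionmaker(bind=create_engine('sqlite://'), autocommit=True)
+     sess = Session()
+     sess.begin()      # no implicit transaction in autocommit mode
+     # ... ORM operations participate in the explicit transaction ...
+     sess.commit()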
+
+ autoexpire
+ When ``True``, all instances will be fully expired after each ``rollback()``
+ and after each ``commit()``, so that all attribute/object access subsequent
+ to a completed transaction will load from the most recent database state.
+
+ autoflush
+ When ``True``, all query operations will issue a ``flush()`` call to this
+ ``Session`` before proceeding. This is a convenience feature so that
+ ``flush()`` need not be called repeatedly in order for database queries to
+ retrieve results. It's typical that ``autoflush`` is used in conjunction with
+ ``autocommit=False``. In this scenario, explicit calls to ``flush()`` are rarely
+ needed; you usually only need to call ``commit()`` (which flushes) to finalize
+ changes.
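+
+ For example, a sketch assuming ``autoflush=True`` and ``autocommit=False``
+ (``SomeClass`` is a hypothetical mapped class)::
+
+     sess.add(SomeClass())
+     sess.query(SomeClass).all()   # autoflush issues flush() first
+     sess.commit()                 # flushes remaining changes and commits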
+
+ bind
+ An optional ``Engine`` or ``Connection`` to which this ``Session`` should be
+ bound. When specified, all SQL operations performed by this session will
+ execute via this connectable.
+
+ binds
+ An optional dictionary, which contains more granular "bind" information than
+ the ``bind`` parameter provides. This dictionary can map individual ``Table``
+ instances as well as ``Mapper`` instances to individual ``Engine`` or
+ ``Connection`` objects. Operations which proceed relative to a particular
+ ``Mapper`` will consult this dictionary for the direct ``Mapper`` instance as
+ well as the mapper's ``mapped_table`` attribute in order to locate an
+ connectable to use. The full resolution is described in the ``get_bind()``
+ method of ``Session``. Usage looks like::
+
+ sess = Session(binds={
+ SomeMappedClass : create_engine('postgres://engine1'),
+ somemapper : create_engine('postgres://engine2'),
+ some_table : create_engine('postgres://engine3'),
+ })
+
+ Also see the ``bind_mapper()`` and ``bind_table()`` methods.
+
+ \class_
+ Specify an alternate class other than ``sqlalchemy.orm.session.Session``
+ which should be used by the returned class. This is the only argument
+ that is local to the ``sessionmaker()`` function, and is not sent
+ directly to the constructor for ``Session``.
+
+ echo_uow
+ When ``True``, configure Python logging to dump all unit-of-work
+ transactions. This is the equivalent of
+ ``logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)``.
+
+ extension
+ An optional [sqlalchemy.orm.session#SessionExtension] instance, which will receive
+ pre- and post- commit and flush events, as well as a post-rollback event. User-
+ defined code may be placed within these hooks using a user-defined subclass
+ of ``SessionExtension``.
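+
+ As an illustrative sketch, a user-defined extension (the hook body
+ is hypothetical)::
+
+     from sqlalchemy.orm import SessionExtension
+
+     class MyExtension(SessionExtension):
+         def after_commit(self, session):
+             print "transaction committed"
+
+     Session = sessionmaker(extension=MyExtension())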
+
+ twophase
+ When ``True``, all transactions will be started using
+ [sqlalchemy.engine#TwoPhaseTransaction]. During a ``commit()``, after
+ ``flush()`` has been issued for all attached databases, the ``prepare()``
+ method on each database's ``TwoPhaseTransaction`` will be called. This allows
+ each database to roll back the entire transaction, before each transaction is
+ committed.
+
+ weak_identity_map
+ When set to the default value of ``True``, a weak-referencing map is used;
+ instances which are not externally referenced will be garbage collected
+ immediately. For dereferenced instances which have pending changes present,
+ the attribute management system will create a temporary strong-reference to
+ the object which lasts until the changes are flushed to the database, at which
+ point it's again dereferenced. Alternatively, when using the value ``False``,
+ the identity map uses a regular Python dictionary to store instances. The
+ session will maintain all instances present until they are removed using
+ expunge(), clear(), or purge().
+
+ """
+
+ if 'transactional' in kwargs:
+ util.warn_deprecated("The 'transactional' argument to sessionmaker() is deprecated; use autocommit=True|False instead.")
+ autocommit = not kwargs.pop('transactional')
+
kwargs['bind'] = bind
kwargs['autoflush'] = autoflush
- kwargs['transactional'] = transactional
+ kwargs['autocommit'] = autocommit
+ kwargs['autoexpire'] = autoexpire
if class_ is None:
class_ = Session
- class Sess(class_):
+ class Sess(object):
def __init__(self, **local_kwargs):
for k in kwargs:
local_kwargs.setdefault(k, kwargs[k])
kwargs.update(new_kwargs)
configure = classmethod(configure)
+ s = type.__new__(type, "Session", (Sess, class_), {})
+ return s
- return Sess
-
-class SessionExtension(object):
- """An extension hook object for Sessions. Subclasses may be installed into a Session
- (or sessionmaker) using the ``extension`` keyword argument.
- """
-
- def before_commit(self, session):
- """Execute right before commit is called.
-
- Note that this may not be per-flush if a longer running transaction is ongoing."""
-
- def after_commit(self, session):
- """Execute after a commit has occured.
-
- Note that this may not be per-flush if a longer running transaction is ongoing."""
-
- def after_rollback(self, session):
- """Execute after a rollback has occured.
-
- Note that this may not be per-flush if a longer running transaction is ongoing."""
-
- def before_flush(self, session, flush_context, instances):
- """Execute before flush process has started.
-
- `instances` is an optional list of objects which were passed to the ``flush()``
- method.
- """
-
- def after_flush(self, session, flush_context):
- """Execute after flush has completed, but before commit has been called.
-
- Note that the session's state is still in pre-flush, i.e. 'new', 'dirty',
- and 'deleted' lists still show pre-flush state as well as the history
- settings on instance attributes."""
-
- def after_flush_postexec(self, session, flush_context):
- """Execute after flush has completed, and after the post-exec state occurs.
-
- This will be when the 'new', 'dirty', and 'deleted' lists are in their final
- state. An actual commit() may or may not have occured, depending on whether or not
- the flush started its own transaction or participated in a larger transaction.
- """
-
- def after_begin(self, session, transaction, connection):
- """Execute after a transaction is begun on a connection
-
- `transaction` is the SessionTransaction. This method is called after an
- engine level transaction is begun on a connection.
- """
class SessionTransaction(object):
"""Represents a Session-level Transaction.
self.nested = nested
self._active = True
self._prepared = False
+ if not parent and nested:
+ raise sa_exc.InvalidRequestError("Can't start a SAVEPOINT transaction when no existing transaction is in progress")
+ self._take_snapshot()
- is_active = property(lambda s: s.session is not None and s._active)
+ def is_active(self):
+ return self.session is not None and self._active
+ is_active = property(is_active)
def _assert_is_active(self):
self._assert_is_open()
if not self._active:
- raise exceptions.InvalidRequestError("The transaction is inactive due to a rollback in a subtransaction and should be closed")
+ raise sa_exc.InvalidRequestError("The transaction is inactive due to a rollback in a subtransaction. Issue rollback() to cancel the transaction.")
def _assert_is_open(self):
if self.session is None:
- raise exceptions.InvalidRequestError("The transaction is closed")
-
+ raise sa_exc.InvalidRequestError("The transaction is closed")
+
+ def _is_transaction_boundary(self):
+ return self.nested or not self._parent
+ _is_transaction_boundary = property(_is_transaction_boundary)
+
def connection(self, bindkey, **kwargs):
self._assert_is_active()
engine = self.session.get_bind(bindkey, **kwargs)
- return self.get_or_add(engine)
+ return self._connection_for_bind(engine)
- def _begin(self, **kwargs):
+ def _begin(self, autoflush=True, nested=False):
self._assert_is_active()
- return SessionTransaction(self.session, self, **kwargs)
+ return SessionTransaction(self.session, self, autoflush=autoflush, nested=nested)
def _iterate_parents(self, upto=None):
if self._parent is upto:
return (self,)
else:
if self._parent is None:
- raise exceptions.InvalidRequestError("Transaction %s is not on the active transaction list" % upto)
+ raise sa_exc.InvalidRequestError("Transaction %s is not on the active transaction list" % upto)
return (self,) + self._parent._iterate_parents(upto)
+
+ def _take_snapshot(self):
+ if not self._is_transaction_boundary:
+ self._new = self._parent._new
+ self._deleted = self._parent._deleted
+ return
+
+ if self.nested:
+ self.session.flush()
+
+ if self.autoflush:
+ # TODO: the "dirty_states" assertion is expensive,
+ # so consider these assertions as temporary
+ # during development
+ assert not self.session._new
+ assert not self.session._deleted
+ assert not self.session._dirty_states
+
+ self._new = weakref.WeakKeyDictionary()
+ self._deleted = weakref.WeakKeyDictionary()
+
+ def _restore_snapshot(self):
+ assert self._is_transaction_boundary
+
+ for s in util.Set(self._deleted).union(self.session._deleted):
+ self.session._update_impl(s)
+
+ assert not self.session._deleted
+
+ for s in util.Set(self._new).union(self.session._new):
+ self.session._expunge_state(s)
+
+ for s in self.session.identity_map.all_states():
+ _expire_state(s, None)
+
+ def _remove_snapshot(self):
+ assert self._is_transaction_boundary
- def add(self, bind):
- self._assert_is_active()
- if self._parent is not None and not self.nested:
- return self._parent.add(bind)
-
- if bind.engine in self._connections:
- raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or ""))
- return self.get_or_add(bind)
-
- def get_or_add(self, bind):
+ if not self.nested and self.session.autoexpire:
+ for s in self.session.identity_map.all_states():
+ _expire_state(s, None)
+
+ def _connection_for_bind(self, bind):
self._assert_is_active()
if bind in self._connections:
return self._connections[bind][0]
- if self._parent is not None:
- conn = self._parent.get_or_add(bind)
+ if self._parent:
+ conn = self._parent._connection_for_bind(bind)
if not self.nested:
return conn
else:
if isinstance(bind, engine.Connection):
conn = bind
if conn.engine in self._connections:
- raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
+ raise sa_exc.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
else:
conn = bind.contextual_connect()
def prepare(self):
if self._parent is not None or not self.session.twophase:
- raise exceptions.InvalidRequestError("Only root two phase transactions of can be prepared")
+ raise sa_exc.InvalidRequestError("Only root two phase transactions of can be prepared")
self._prepare_impl()
-
+
def _prepare_impl(self):
self._assert_is_active()
if self.session.extension is not None and (self._parent is None or self.nested):
if self.session.extension is not None:
self.session.extension.after_commit(self.session)
-
+
+ self._remove_snapshot()
+
self.close()
return self._parent
-
+
def rollback(self):
self._assert_is_open()
for t in util.Set(self._connections.values()):
t[1].rollback()
+ self._restore_snapshot()
+
if self.session.extension is not None:
self.session.extension.after_rollback(self.session)
self._deactivate()
self.session = None
self._connections = None
-
+
def __enter__(self):
return self
* *Transient* - an instance that's not in a session, and is not saved to the database;
i.e. it has no database identity. The only relationship such an object has to the ORM
- is that its class has a `mapper()` associated with it.
+ is that its class has a ``mapper()`` associated with it.
- * *Pending* - when you `save()` a transient instance, it becomes pending. It still
+ * *Pending* - when you ``add()`` a transient instance, it becomes pending. It
has not actually been flushed to the database yet, but it will be when the next flush
occurs.
they're detached, **except** they will not be able to issue any SQL in order to load
collections or attributes which are not yet loaded, or were marked as "expired".
- The session methods which control instance state include ``save()``, ``update()``,
- ``save_or_update()``, ``delete()``, ``merge()``, and ``expunge()``.
+ The session methods which control instance state include ``add()``, ``delete()``,
+ ``merge()``, and ``expunge()``.
- The Session object is **not** threadsafe, particularly during flush operations. A session
- which is only read from (i.e. is never flushed) can be used by concurrent threads if it's
- acceptable that some object instances may be loaded twice.
+ The Session object is generally **not** threadsafe. A session which is set to ``autocommit``
+ and is only read from may be used by concurrent threads if it's acceptable that some object
+ instances may be loaded twice.
The typical pattern to managing Sessions in a multi-threaded environment is either to use
mutexes to limit concurrent access to one thread at a time, or more commonly to establish
a unique session for every thread, using a threadlocal variable. SQLAlchemy provides
a thread-managed Session adapter, provided by the [sqlalchemy.orm#scoped_session()] function.
+
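+ A minimal sketch of the thread-local pattern (the engine URL is a
+ placeholder)::
+
+     from sqlalchemy import create_engine
+     from sqlalchemy.orm import scoped_session, sessionmaker
+
+     Session = scoped_session(sessionmaker(bind=create_engine('sqlite://')))
+     sess = Session()   # each thread calling Session() gets its own Session
+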
"""
-
- def __init__(self, bind=None, autoflush=True, transactional=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None):
+ def __init__(self, bind=None, autoflush=True, autoexpire=True, autocommit=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None):
"""Construct a new Session.
- A session is usually constructed using the [sqlalchemy.orm#create_session()] function,
- or its more "automated" variant [sqlalchemy.orm#sessionmaker()].
-
- autoflush
- When ``True``, all query operations will issue a ``flush()`` call to this
- ``Session`` before proceeding. This is a convenience feature so that
- ``flush()`` need not be called repeatedly in order for database queries to
- retrieve results. It's typical that ``autoflush`` is used in conjunction with
- ``transactional=True``, so that ``flush()`` is never called; you just call
- ``commit()`` when changes are complete to finalize all changes to the
- database.
-
- bind
- An optional ``Engine`` or ``Connection`` to which this ``Session`` should be
- bound. When specified, all SQL operations performed by this session will
- execute via this connectable.
-
- binds
- An optional dictionary, which contains more granular "bind" information than
- the ``bind`` parameter provides. This dictionary can map individual ``Table``
- instances as well as ``Mapper`` instances to individual ``Engine`` or
- ``Connection`` objects. Operations which proceed relative to a particular
- ``Mapper`` will consult this dictionary for the direct ``Mapper`` instance as
- well as the mapper's ``mapped_table`` attribute in order to locate an
- connectable to use. The full resolution is described in the ``get_bind()``
- method of ``Session``. Usage looks like::
-
- sess = Session(binds={
- SomeMappedClass : create_engine('postgres://engine1'),
- somemapper : create_engine('postgres://engine2'),
- some_table : create_engine('postgres://engine3'),
- })
-
- Also see the ``bind_mapper()`` and ``bind_table()`` methods.
-
- echo_uow
- When ``True``, configure Python logging to dump all unit-of-work
- transactions. This is the equivalent of
- ``logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)``.
-
- extension
- An optional [sqlalchemy.orm.session#SessionExtension] instance, which will receive
- pre- and post- commit and flush events, as well as a post-rollback event. User-
- defined code may be placed within these hooks using a user-defined subclass
- of ``SessionExtension``.
-
- transactional
- Set up this ``Session`` to automatically begin transactions. Setting this
- flag to ``True`` is the rough equivalent of calling ``begin()`` after each
- ``commit()`` operation, after each ``rollback()``, and after each
- ``close()``. Basically, this has the effect that all session operations are
- performed within the context of a transaction. Note that the ``begin()``
- operation does not immediately utilize any connection resources; only when
- connection resources are first required do they get allocated into a
- transactional context.
-
- twophase
- When ``True``, all transactions will be started using
- [sqlalchemy.engine_TwoPhaseTransaction]. During a ``commit()``, after
- ``flush()`` has been issued for all attached databases, the ``prepare()``
- method on each database's ``TwoPhaseTransaction`` will be called. This allows
- each database to roll back the entire transaction, before each transaction is
- committed.
-
- weak_identity_map
- When set to the default value of ``False``, a weak-referencing map is used;
- instances which are not externally referenced will be garbage collected
- immediately. For dereferenced instances which have pending changes present,
- the attribute management system will create a temporary strong-reference to
- the object which lasts until the changes are flushed to the database, at which
- point it's again dereferenced. Alternatively, when using the value ``True``,
- the identity map uses a regular Python dictionary to store instances. The
- session will maintain all instances present until they are removed using
- expunge(), clear(), or purge().
+ Arguments to ``Session`` are described using the [sqlalchemy.orm#sessionmaker()] function.
+
"""
self.echo_uow = echo_uow
- self.weak_identity_map = weak_identity_map
- self.uow = unitofwork.UnitOfWork(self)
- self.identity_map = self.uow.identity_map
+ if weak_identity_map:
+ self._identity_cls = identity.WeakInstanceDict
+ else:
+ self._identity_cls = identity.StrongInstanceDict
+ self.identity_map = self._identity_cls()
+ self._new = {} # InstanceState->object, strong refs object
+ self._deleted = {} # same
self.bind = bind
self.__binds = {}
self.transaction = None
self.hash_key = id(self)
self.autoflush = autoflush
- self.transactional = transactional
+ self.autocommit = autocommit
+ self.autoexpire = autoexpire
self.twophase = twophase
self.extension = extension
self._query_cls = query.Query
for t in mapperortable._all_tables:
self.__binds[t] = value
- if self.transactional:
+ if not self.autocommit:
self.begin()
_sessions[self.hash_key] = self
- def begin(self, **kwargs):
- """Begin a transaction on this Session."""
-
+ def begin(self, subtransactions=False, nested=False, _autoflush=True):
+ """Begin a transaction on this Session.
+
+ If this Session is already within a transaction,
+ either a plain transaction or nested transaction,
+ an error is raised, unless ``subtransactions=True``
+ or ``nested=True`` is specified.
+
+ The ``subtransactions=True`` flag indicates that
+ this ``begin()`` can create a subtransaction if a
+ transaction is already in progress. A subtransaction
+ is a non-transactional, delimiting construct that
+ allows matching begin()/commit() pairs to be nested
+ together, with only the outermost begin/commit pair
+ actually affecting transactional state. When a rollback
+ is issued, the subtransaction will directly roll back
+ the innermost real transaction, however each subtransaction
+ still must be explicitly rolled back to maintain proper
+ stacking of subtransactions.
+
+ If no transaction is in progress,
+ then a real transaction is begun.
+
+ The ``nested`` flag begins a SAVEPOINT transaction
+ and is equivalent to calling ``begin_nested()``.
+
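+ An illustrative sketch of matched begin()/commit() pairs::
+
+     sess.begin(subtransactions=True)   # outer begin()
+     sess.begin(subtransactions=True)   # inner begin(); a subtransaction
+     sess.commit()                      # closes the subtransaction only
+     sess.commit()                      # commits at the outer level
+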
+ """
if self.transaction is not None:
- self.transaction = self.transaction._begin(**kwargs)
+ if subtransactions or nested:
+ self.transaction = self.transaction._begin(nested=nested, autoflush=_autoflush)
+ else:
+ raise sa_exc.InvalidRequestError("A transaction is already begun. Use subtransactions=True to allow subtransactions.")
else:
- self.transaction = SessionTransaction(self, **kwargs)
- return self.transaction
-
- create_transaction = begin
+ self.transaction = SessionTransaction(self, nested=nested, autoflush=_autoflush)
+ return self.transaction # needed for __enter__/__exit__ hook
def begin_nested(self):
"""Begin a `nested` transaction on this Session.
This utilizes a ``SAVEPOINT`` transaction for databases
which support this feature.
- """
+ The nested transaction is a real transaction, unlike
+ a "subtransaction" which corresponds to multiple
+ ``begin()`` calls. The next ``rollback()`` or
+ ``commit()`` call will operate upon this nested
+ transaction.
+
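+ For example (sketch; ``someobject`` is hypothetical)::
+
+     sess.begin_nested()     # emits SAVEPOINT
+     sess.add(someobject)
+     sess.rollback()         # rolls back to the SAVEPOINT only
+     sess.commit()           # commits the enclosing transaction
+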
+ """
return self.begin(nested=True)
def rollback(self):
If no transaction is in progress, this method is a
pass-thru.
+
+ This method rolls back the current transaction
+ or nested transaction regardless of subtransactions
+ being in effect. All subtransactions up to the
+ first real transaction are closed. Subtransactions
+ occur when begin() is called multiple times.
+
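+ For example (sketch)::
+
+     sess.begin(subtransactions=True)
+     try:
+         sess.flush()
+     except:
+         sess.rollback()     # rolls back the enclosing real transaction
+         raise
+     sess.commit()
+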
"""
-
if self.transaction is None:
pass
else:
self.transaction.rollback()
- # TODO: we can rollback attribute values. however
- # we would want to expand attributes.py to be able to save *two* rollback points, one to the
- # last flush() and the other to when the object first entered the transaction.
- # [ticket:705]
- #attributes.rollback(*self.identity_map.values())
- if self.transaction is None and self.transactional:
+ if self.transaction is None and not self.autocommit:
self.begin()
def commit(self):
- """Commit the current transaction in progress.
+ """Flush any pending changes, and commit the current transaction
+ in progress, assuming no subtransactions are in effect.
If no transaction is in progress, this method raises
an InvalidRequestError.
+
+ If a subtransaction is in effect (which occurs when
+ begin() is called multiple times), the subtransaction
+ will be closed, and the next call to ``commit()``
+ will operate on the enclosing transaction.
- If the ``begin()`` method was called on this ``Session``
- additional times subsequent to its first call,
- ``commit()`` will not actually commit, and instead
- pops an internal SessionTransaction off its internal stack
- of transactions. Only when the "root" SessionTransaction
- is reached does an actual database-level commit occur.
- """
+ For a session configured with autocommit=False, a new
+ transaction will be begun immediately after the commit,
+ but note that the newly begun transaction does *not*
+ use any connection resources until the first SQL is
+ actually emitted.
+ """
if self.transaction is None:
- if self.transactional:
+ if not self.autocommit:
self.begin()
else:
- raise exceptions.InvalidRequestError("No transaction is begun.")
+ raise sa_exc.InvalidRequestError("No transaction is begun.")
self.transaction.commit()
- if self.transaction is None and self.transactional:
+ if self.transaction is None and not self.autocommit:
self.begin()
def prepare(self):
not such, an InvalidRequestError is raised.
"""
if self.transaction is None:
- if self.transactional:
+ if not self.autocommit:
self.begin()
else:
- raise exceptions.InvalidRequestError("No transaction is begun.")
+ raise sa_exc.InvalidRequestError("No transaction is begun.")
self.transaction.prepare()
def __connection(self, engine, **kwargs):
if self.transaction is not None:
- return self.transaction.get_or_add(engine)
+ return self.transaction._connection_for_bind(engine)
else:
return engine.contextual_connect(**kwargs)
the proper bind, in the case of ShardedSession.
"""
+ clause = expression._literal_as_text(clause)
+
engine = self.get_bind(mapper, clause=clause, instance=instance)
return self.__connection(engine, close_with_result=True).execute(clause, params or {})
if self.transaction is not None:
for transaction in self.transaction._iterate_parents():
transaction.close()
- if self.transactional:
+ if not self.autocommit:
# note this doesn't use any connection resources
self.begin()
sess.close()
close_all = classmethod(close_all)
- def clear(self):
+ def expunge_all(self):
"""Remove all object instances from this ``Session``.
This is equivalent to calling ``expunge()`` for all objects in
this ``Session``.
"""
- for instance in self:
- self._unattach(instance)
- self.uow = unitofwork.UnitOfWork(self)
- self.identity_map = self.uow.identity_map
+ for state in self.identity_map.all_states() + list(self._new):
+ del state.session_id
+ self.identity_map = self._identity_cls()
+ self._new = {}
+ self._deleted = {}
+ clear = expunge_all
+
+ # TODO: deprecate
+ #clear = util.deprecated()(expunge_all)
+
# TODO: need much more test coverage for bind_mapper() and similar !
def bind_mapper(self, mapper, bind, entity_name=None):
"""
if mapper is None and clause is None:
- if self.bind is not None:
+ if self.bind:
return self.bind
else:
- raise exceptions.UnboundExecutionError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
+ raise sa_exc.UnboundExecutionError("This session is not bound to any Engine or Connection; specify a mapper to get_bind()")
- elif len(self.__binds):
- if mapper is not None:
- if isinstance(mapper, type):
- mapper = _class_mapper(mapper)
+ elif self.__binds:
+ if mapper:
+ mapper = _class_to_mapper(mapper)
if mapper.base_mapper in self.__binds:
return self.__binds[mapper.base_mapper]
- elif mapper.compile().mapped_table in self.__binds:
+ elif mapper.mapped_table in self.__binds:
return self.__binds[mapper.mapped_table]
- if clause is not None:
- for t in clause._table_iterator():
+ if clause:
+ for t in sql_util.find_tables(clause):
if t in self.__binds:
return self.__binds[t]
- if self.bind is not None:
+ if self.bind:
return self.bind
- elif isinstance(clause, sql.expression.ClauseElement) and clause.bind is not None:
+ elif isinstance(clause, sql.expression.ClauseElement) and clause.bind:
return clause.bind
- elif mapper is None:
- raise exceptions.UnboundExecutionError("Could not locate any mapper associated with SQL expression")
+ elif not mapper:
+ raise sa_exc.UnboundExecutionError("Could not locate any mapper associated with SQL expression")
else:
- if isinstance(mapper, type):
- mapper = _class_mapper(mapper)
- else:
- mapper = mapper.compile()
+ mapper = _class_to_mapper(mapper)
e = mapper.mapped_table.bind
if e is None:
- raise exceptions.UnboundExecutionError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
+ raise sa_exc.UnboundExecutionError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
return e
- def query(self, mapper_or_class, *addtl_entities, **kwargs):
- """Return a new ``Query`` object corresponding to this ``Session`` and
- the mapper, or the classes' primary mapper.
-
- """
- entity_name = kwargs.pop('entity_name', None)
-
- if isinstance(mapper_or_class, type):
- q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
- else:
- q = self._query_cls(mapper_or_class, self, **kwargs)
-
- for ent in addtl_entities:
- q = q.add_entity(ent)
- return q
-
+ def query(self, *entities, **kwargs):
+ """Return a new ``Query`` object corresponding to this ``Session``."""
+
+ return self._query_cls(entities, self, **kwargs)
def _autoflush(self):
if self.autoflush and (self.transaction is None or self.transaction.autoflush):
self.flush()
+
+ def _finalize_loaded(self, states):
+ for state in states:
+ state.commit_all()
- def flush(self, objects=None):
- """Flush all the object modifications present in this session
- to the database.
-
- `objects` is a collection or iterator of objects specifically to be
- flushed; if ``None``, all new and modified objects are flushed.
-
- """
- if objects is not None:
- try:
- if not len(objects):
- return
- except TypeError:
- objects = list(objects)
- if not objects:
- return
- self.uow.flush(self, objects)
-
- def get(self, class_, ident, **kwargs):
+ def get(self, class_, ident, entity_name=None):
"""Return an instance of the object based on the given
identifier, or ``None`` if not found.
query.
"""
- entity_name = kwargs.pop('entity_name', None)
- return self.query(class_, entity_name=entity_name).get(ident, **kwargs)
+ return self.query(class_, entity_name=entity_name).get(ident)
- def load(self, class_, ident, **kwargs):
+ def load(self, class_, ident, entity_name=None):
"""Return an instance of the object based on the given
identifier.
query.
"""
- entity_name = kwargs.pop('entity_name', None)
- return self.query(class_, entity_name=entity_name).load(ident, **kwargs)
+ return self.query(class_, entity_name=entity_name).load(ident)
def refresh(self, instance, attribute_names=None):
"""Refresh the attributes on the given instance.
refreshed.
"""
- self._validate_persistent(instance)
+ state = attributes.instance_state(instance)
+ self._validate_persistent(state)
+ if self.query(_object_mapper(instance))._get(
+ state.key, refresh_instance=state,
+ only_load_props=attribute_names) is None:
+ raise sa_exc.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
- if self.query(_object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
- raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
-
def expire_all(self):
"""Expires all persistent instances within this Session.
of attribute names indicating a subset of attributes to be
expired.
"""
-
+ state = attributes.instance_state(instance)
+ self._validate_persistent(state)
if attribute_names:
- self._validate_persistent(instance)
- _expire_state(instance._state, attribute_names=attribute_names)
+ _expire_state(state, attribute_names=attribute_names)
else:
# pre-fetch the full cascade since the expire is going to
# remove associations
- cascaded = list(_cascade_iterator('refresh-expire', instance))
- self._validate_persistent(instance)
- _expire_state(instance._state, None)
- for (c, m) in cascaded:
- self._validate_persistent(c)
- _expire_state(c._state, None)
+ cascaded = list(_cascade_state_iterator('refresh-expire', state))
+ _expire_state(state, None)
+ for (state, m) in cascaded:
+ _expire_state(state, None)
def prune(self):
"""Remove unreferenced instances cached in the identity map.
Returns the number of objects pruned.
"""
- return self.uow.prune_identity_map()
+ return self.identity_map.prune()
def expunge(self, instance):
"""Remove the given `instance` from this ``Session``.
Cascading will be applied according to the *expunge* cascade
rule.
"""
- self._validate_persistent(instance)
- for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)):
- if c in self:
- self.uow._remove_deleted(c._state)
- self._unattach(c)
+
+ state = attributes.instance_state(instance)
+ if state.session_id is not self.hash_key:
+ raise sa_exc.InvalidRequestError("Instance %s is not present in this Session" % mapperutil.state_str(state))
+ for s, m in [(state, None)] + list(_cascade_state_iterator('expunge', state)):
+ self._expunge_state(s)
+
+ def _expunge_state(self, state):
+ if state in self._new:
+ self._new.pop(state)
+ del state.session_id
+ elif self.identity_map.contains_state(state):
+ self.identity_map.discard(state)
+ self._deleted.pop(state, None)
+ del state.session_id
+
+ def _register_newly_persistent(self, state):
+ mapper = _state_mapper(state)
+ instance_key = mapper._identity_key_from_state(state)
+
+ if state.key is None:
+ state.key = instance_key
+ elif state.key != instance_key:
+ # primary key switch
+ self.identity_map.remove(state)
+ state.key = instance_key
+
+ if hasattr(state, 'insert_order'):
+ delattr(state, 'insert_order')
+
+ obj = state.obj()
+ # guard against last-minute dereferences of the object
+ # TODO: identify a code path where state.obj() is None
+ if obj is not None:
+ if state.key in self.identity_map and not self.identity_map.contains_state(state):
+ self.identity_map.remove_key(state.key)
+ self.identity_map.add(state)
+ state.commit_all()
+
+ # remove from new last, might be the last strong ref
+ if state in self._new:
+ if self.transaction:
+ self.transaction._new[state] = True
+ self._new.pop(state)
+
+ def _remove_newly_deleted(self, state):
+ if self.transaction:
+ self.transaction._deleted[state] = True
+
+ self.identity_map.discard(state)
+ self._deleted.pop(state, None)
+ del state.session_id
def save(self, instance, entity_name=None):
"""Add a transient (unsaved) instance to this ``Session``.
The `entity_name` keyword argument will further qualify the
specific ``Mapper`` used to handle this instance.
+
"""
- self._save_impl(instance, entity_name=entity_name)
- self._cascade_save_or_update(instance)
-
+ state = _state_for_unsaved_instance(instance, entity_name)
+ self._save_impl(state)
+ self._cascade_save_or_update(state, entity_name)
+
+ # TODO
+ #save = util.deprecated("Use the add() method.")(save)
+
+ def _save_without_cascade(self, instance, entity_name=None):
+ """used by scoping.py to save on init without cascade."""
+
+ state = _state_for_unsaved_instance(instance, entity_name)
+ self._save_impl(state)
+
def update(self, instance, entity_name=None):
"""Bring the given detached (saved) instance into this
``Session``.
This operation cascades the `save_or_update` method to
associated instances if the relation is mapped with
``cascade="save-update"``.
+
"""
+ state = attributes.instance_state(instance)
+ self._update_impl(state)
+ self._cascade_save_or_update(state, entity_name)
+
+ # TODO
+ #update = util.deprecated("Use the add() method.")(update)
+
+ def add(self, instance, entity_name=None):
+ """Add the given instance into this ``Session``.
- self._update_impl(instance, entity_name=entity_name)
- self._cascade_save_or_update(instance)
-
- def save_or_update(self, instance, entity_name=None):
- """Save or update the given instance into this ``Session``.
+ A non-None ``key`` on the instance's state determines whether
+ to ``save()`` or ``update()`` the instance.
- The presence of an `_instance_key` attribute on the instance
- determines whether to ``save()`` or ``update()`` the instance.
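+
+ For example (sketch; ``User`` and ``detached_user`` are hypothetical)::
+
+     sess.add(User(name='ed'))   # transient instance becomes pending
+     sess.add(detached_user)     # detached instance becomes persistent
+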
"""
-
- self._save_or_update_impl(instance, entity_name=entity_name)
- self._cascade_save_or_update(instance)
-
- def _cascade_save_or_update(self, instance):
- for obj, mapper in _cascade_iterator('save-update', instance, halt_on=lambda c:c in self):
- self._save_or_update_impl(obj, mapper.entity_name)
+ state = _state_for_unknown_persistence_instance(instance, entity_name)
+ self._save_or_update_state(state, entity_name)
+
+ def add_all(self, instances):
+ """Add the given collection of instances to this ``Session``."""
+
+ for instance in instances:
+ self.add(instance)
+
+ # TODO
+ # save_or_update = util.deprecated("Use the add() method.")(add)
+ save_or_update = add
+
+ def _save_or_update_state(self, state, entity_name):
+ self._save_or_update_impl(state)
+ self._cascade_save_or_update(state, entity_name)
+
+ def _cascade_save_or_update(self, state, entity_name):
+ for state, mapper in _cascade_unknown_state_iterator('save-update', state, halt_on=lambda c:c in self):
+ self._save_or_update_impl(state)
def delete(self, instance):
"""Mark the given instance as deleted.
The delete operation occurs upon ``flush()``.
"""
- self._delete_impl(instance)
- for c, m in _cascade_iterator('delete', instance):
- self._delete_impl(c, ignore_transient=True)
+ state = attributes.instance_state(instance)
+ self._delete_impl(state)
+ for state, m in _cascade_state_iterator('delete', state):
+ self._delete_impl(state, ignore_transient=True)
def merge(self, instance, entity_name=None, dont_load=False, _recursive=None):
if instance in _recursive:
return _recursive[instance]
- key = getattr(instance, '_instance_key', None)
+ new_instance = False
+ state = attributes.instance_state(instance)
+ key = state.key
if key is None:
if dont_load:
- raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects. flush() all changes on mapped instances before merging with dont_load=True.")
- key = mapper.identity_key_from_instance(instance)
+ raise sa_exc.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects. flush() all changes on mapped instances before merging with dont_load=True.")
+ key = mapper._identity_key_from_state(state)
merged = None
if key:
if key in self.identity_map:
merged = self.identity_map[key]
elif dont_load:
- if instance._state.modified:
- raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True.")
-
- merged = attributes.new_instance(mapper.class_)
- merged._instance_key = key
- merged._entity_name = entity_name
- self._update_impl(merged, entity_name=mapper.entity_name)
+ if state.modified:
+ raise sa_exc.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True.")
+
+ merged = mapper.class_manager.new_instance()
+ merged_state = attributes.instance_state(merged)
+ merged_state.key = key
+ merged_state.entity_name = entity_name
+ self._update_impl(merged_state)
+ new_instance = True
else:
merged = self.get(mapper.class_, key[1])
-
+
if merged is None:
- merged = attributes.new_instance(mapper.class_)
+ merged = mapper.class_manager.new_instance()
+ merged_state = attributes.instance_state(merged)
+ new_instance = True
self.save(merged, entity_name=mapper.entity_name)
-
+
_recursive[instance] = merged
-
+
for prop in mapper.iterate_properties:
prop.merge(self, instance, merged, dont_load, _recursive)
-
+
if dont_load:
- merged._state.commit_all() # remove any history
+ attributes.instance_state(merged).commit_all() # remove any history
+ if new_instance:
+ merged_state._run_on_load(merged)
return merged
def identity_key(cls, *args, **kwargs):
- """Get an identity key.
-
- Valid call signatures:
-
- * ``identity_key(class, ident, entity_name=None)``
-
- class
- mapped class (must be a positional argument)
-
- ident
- primary key, if the key is composite this is a tuple
-
- entity_name
- optional entity name
-
- * ``identity_key(instance=instance)``
-
- instance
- object instance (must be given as a keyword arg)
-
- * ``identity_key(class, row=row, entity_name=None)``
-
- class
- mapped class (must be a positional argument)
-
- row
- result proxy row (must be given as a keyword arg)
-
- entity_name
- optional entity name (must be given as a keyword arg)
- """
-
- if args:
- if len(args) == 1:
- class_ = args[0]
- try:
- row = kwargs.pop("row")
- except KeyError:
- ident = kwargs.pop("ident")
- entity_name = kwargs.pop("entity_name", None)
- elif len(args) == 2:
- class_, ident = args
- entity_name = kwargs.pop("entity_name", None)
- elif len(args) == 3:
- class_, ident, entity_name = args
- else:
- raise exceptions.ArgumentError("expected up to three "
- "positional arguments, got %s" % len(args))
- if kwargs:
- raise exceptions.ArgumentError("unknown keyword arguments: %s"
- % ", ".join(kwargs.keys()))
- mapper = _class_mapper(class_, entity_name=entity_name)
- if "ident" in locals():
- return mapper.identity_key_from_primary_key(ident)
- return mapper.identity_key_from_row(row)
- instance = kwargs.pop("instance")
- if kwargs:
- raise exceptions.ArgumentError("unknown keyword arguments: %s"
- % ", ".join(kwargs.keys()))
- mapper = _object_mapper(instance)
- return mapper.identity_key_from_instance(instance)
+ return mapperutil.identity_key(*args, **kwargs)
identity_key = classmethod(identity_key)
def object_session(cls, instance):
return object_session(instance)
object_session = classmethod(object_session)
- def _save_impl(self, instance, **kwargs):
- if hasattr(instance, '_instance_key'):
- raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(instance))
- else:
- # TODO: consolidate the steps here
- attributes.manage(instance)
- instance._entity_name = kwargs.get('entity_name', None)
- self._attach(instance)
- self.uow.register_new(instance)
-
- def _update_impl(self, instance, **kwargs):
- if instance in self and instance not in self.deleted:
+ def _validate_persistent(self, state):
+ if not self.identity_map.contains_state(state):
+ raise sa_exc.InvalidRequestError("Instance '%s' is not persistent within this Session" % mapperutil.state_str(state))
+
+ def _save_impl(self, state):
+ if state.key is not None:
+ raise sa_exc.InvalidRequestError(
+ "Object '%s' already has an identity - it can't be registered "
+ "as pending" % repr(obj))
+ self._attach(state)
+ if state not in self._new:
+ self._new[state] = state.obj()
+ state.insert_order = len(self._new)
+
+ def _update_impl(self, state):
+ if self.identity_map.contains_state(state) and state not in self._deleted:
return
- if not hasattr(instance, '_instance_key'):
- raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
- elif self.identity_map.get(instance._instance_key, instance) is not instance:
- raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(instance), instance._instance_key))
- self._attach(instance)
-
- def _save_or_update_impl(self, instance, entity_name=None):
- key = getattr(instance, '_instance_key', None)
- if key is None:
- self._save_impl(instance, entity_name=entity_name)
+
+ if state.key is None:
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is not persisted" %
+ mapperutil.state_str(state))
+
+ if state.key in self.identity_map and not self.identity_map.contains_state(state):
+ raise sa_exc.InvalidRequestError(
+ "Could not update instance '%s', identity key %s; a different "
+ "instance with the same identity key already exists in this "
+ "session." % (mapperutil.state_str(state), state.key))
+
+ self._attach(state)
+ self._deleted.pop(state, None)
+ self.identity_map.add(state)
+
+ def _save_or_update_impl(self, state):
+ if state.key is None:
+ self._save_impl(state)
else:
- self._update_impl(instance, entity_name=entity_name)
+ self._update_impl(state)
- def _delete_impl(self, instance, ignore_transient=False):
- if instance in self and instance in self.deleted:
+ def _delete_impl(self, state, ignore_transient=False):
+ if self.identity_map.contains_state(state) and state in self._deleted:
return
- if not hasattr(instance, '_instance_key'):
+
+ if state.key is None:
if ignore_transient:
return
else:
- raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
- if self.identity_map.get(instance._instance_key, instance) is not instance:
- raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(instance), instance._instance_key))
- self._attach(instance)
- self.uow.register_deleted(instance)
-
- def _attach(self, instance):
- old_id = getattr(instance, '_sa_session_id', None)
- if old_id != self.hash_key:
- if old_id is not None and old_id in _sessions and instance in _sessions[old_id]:
- raise exceptions.InvalidRequestError("Object '%s' is already attached "
- "to session '%s' (this is '%s')" %
- (mapperutil.instance_str(instance), old_id, id(self)))
-
- key = getattr(instance, '_instance_key', None)
- if key is not None:
- self.identity_map[key] = instance
- instance._sa_session_id = self.hash_key
-
- def _unattach(self, instance):
- if instance._sa_session_id == self.hash_key:
- del instance._sa_session_id
-
- def _validate_persistent(self, instance):
- """Validate that the given instance is persistent within this
- ``Session``.
- """
-
- if instance not in self:
- raise exceptions.InvalidRequestError("Instance '%s' is not persistent within this Session" % mapperutil.instance_str(instance))
+ raise sa_exc.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.state_str(state))
+ if state.key in self.identity_map and not self.identity_map.contains_state(state):
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is with key %s already persisted with a "
+ "different identity" % (mapperutil.state_str(state),
+ state.key))
+
+ self._deleted[state] = state.obj()
+ self._attach(state)
+
+ def _attach(self, state):
+ if state.session_id and state.session_id is not self.hash_key:
+ raise sa_exc.InvalidRequestError(
+ "Object '%s' is already attached to session '%s' "
+ "(this is '%s')" % (mapperutil.state_str(state),
+ state.session_id, self.hash_key))
+ if state.session_id != self.hash_key:
+ state.session_id = self.hash_key
def __contains__(self, instance):
"""Return True if the given instance is associated with this session.
The instance may be pending or persistent within the Session for a
result of True.
- """
-
- return instance._state in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance)
+ """
+ return self._contains_state(attributes.instance_state(instance))
+
def __iter__(self):
"""Return an iterator of all instances which are pending or persistent within this Session."""
- return iter(list(self.uow.new.values()) + self.uow.identity_map.values())
+ return iter(list(self._new.values()) + self.identity_map.values())
+
+ def _contains_state(self, state):
+ return state in self._new or self.identity_map.contains_state(state)
+
+
+ def flush(self, objects=None):
+ """Flush all the object modifications present in this session
+ to the database.
+
+ `objects` is a list or tuple of objects specifically to be
+ flushed; if ``None``, all new and modified objects are flushed.
+
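+ For example (sketch; ``someobject`` is hypothetical)::
+
+     sess.add(someobject)
+     sess.flush()    # emits pending INSERT/UPDATE/DELETE statements
+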
+ """
+ if not self.identity_map.check_modified() and not self._deleted and not self._new:
+ return
+
+ dirty = self._dirty_states
+ if not dirty and not self._deleted and not self._new:
+ self.identity_map.modified = False
+ return
+
+ deleted = util.Set(self._deleted)
+ new = util.Set(self._new)
+
+ dirty = util.Set(dirty).difference(deleted)
+
+ flush_context = UOWTransaction(self)
+
+ if self.extension is not None:
+ self.extension.before_flush(self, flush_context, objects)
+
+ # create the set of all objects we want to operate upon
+ if objects:
+ # specific list passed in
+ objset = util.Set([attributes.instance_state(o) for o in objects])
+ else:
+ # or just everything
+ objset = util.Set(self.identity_map.all_states()).union(new)
+
+ # store objects whose fate has been decided
+ processed = util.Set()
+
+ # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted.
+ for state in new.union(dirty).intersection(objset).difference(deleted):
+ is_orphan = _state_mapper(state)._is_orphan(state)
+ if is_orphan and not _state_has_identity(state):
+ raise exc.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" %
+ (
+ mapperutil.state_str(state),
+ ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans])
+ ))
+ flush_context.register_object(state, isdelete=is_orphan)
+ processed.add(state)
+
+ # put all remaining deletes into the flush context.
+ for state in deleted.intersection(objset).difference(processed):
+ flush_context.register_object(state, isdelete=True)
+
+ if len(flush_context.tasks) == 0:
+ return
+
+ flush_context.transaction = transaction = self.begin(subtransactions=True, _autoflush=False)
+ try:
+ flush_context.execute()
+
+ if self.extension is not None:
+ self.extension.after_flush(self, flush_context)
+ transaction.commit()
+ except:
+ transaction.rollback()
+ raise
+
+ flush_context.finalize_flush_changes()
+
+ if not objects:
+ self.identity_map.modified = False
+
+ if self.extension is not None:
+ self.extension.after_flush_postexec(self, flush_context)
def is_modified(self, instance, include_collections=True, passive=False):
"""Return True if the given instance has modified attributes.
not be loaded in the course of performing this test.
"""
- for attr in attributes._managed_attributes(instance.__class__):
+ for attr in attributes.manager_of_class(instance.__class__).attributes:
if not include_collections and hasattr(attr.impl, 'get_collection'):
continue
(added, unchanged, deleted) = attr.get_history(instance)
return True
return False
+ def _dirty_states(self):
+ """Return a set of all persistent states considered dirty.
+
+ This method returns all states that were modified including those that
+ were possibly deleted.
+
+ """
+ return util.IdentitySet(
+ [state for state in self.identity_map.all_states() if state.check_modified()]
+ )
+ _dirty_states = property(_dirty_states)
+
def dirty(self):
- """Return a ``Set`` of all instances marked as 'dirty' within this ``Session``.
+ """Return a set of all persistent instances considered dirty.
+
+ Instances are considered dirty when they were modified but not
+ deleted.
Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection
modification operations will mark an instance as 'dirty' and place it in this set,
To check if an instance has actionable net changes to its attributes, use the
is_modified() method.
+
"""
+
+ return util.IdentitySet(
+ [state.obj() for state in self._dirty_states if state not in self._deleted]
+ )
- return self.uow.locate_dirty()
dirty = property(dirty)
def deleted(self):
"Return a ``Set`` of all instances marked as 'deleted' within this ``Session``"
- return util.IdentitySet(self.uow.deleted.values())
+ return util.IdentitySet(self._deleted.values())
deleted = property(deleted)
def new(self):
"Return a ``Set`` of all instances marked as 'new' within this ``Session``."
- return util.IdentitySet(self.uow.new.values())
+ return util.IdentitySet(self._new.values())
new = property(new)
def _expire_state(state, attribute_names):
_sessions = weakref.WeakValueDictionary()
-def _cascade_iterator(cascade, instance, **kwargs):
- mapper = _object_mapper(instance)
- for (o, m) in mapper.cascade_iterator(cascade, instance._state, **kwargs):
- yield o, m
+def _cascade_state_iterator(cascade, state, **kwargs):
+ mapper = _state_mapper(state)
+ for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
+ yield attributes.instance_state(o), m
+
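+# like _cascade_state_iterator, but yields a state for each cascaded
+# instance without assuming whether it is pending or persistent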
+def _cascade_unknown_state_iterator(cascade, state, **kwargs):
+ mapper = _state_mapper(state)
+ for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
+ yield _state_for_unknown_persistence_instance(o, m.entity_name), m
+
+def _state_for_unsaved_instance(instance, entity_name):
+ manager = attributes.manager_of_class(instance.__class__)
+ if manager is None:
+ raise "FIXME unmapped instance"
+ if manager.has_state(instance):
+ state = manager.state_of(instance)
+ if state.key is not None:
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is already persistent" %
+ mapperutil.state_str(state))
+ else:
+ state = manager.setup_instance(instance)
+ state.entity_name = entity_name
+ return state
+
+def _state_for_unknown_persistence_instance(instance, entity_name):
+ state = attributes.instance_state(instance)
+ state.entity_name = entity_name
+ return state
def object_session(instance):
"""Return the ``Session`` to which the given instance is bound, or ``None`` if none."""
- hashkey = getattr(instance, '_sa_session_id', None)
- if hashkey is not None:
- sess = _sessions.get(hashkey)
- if sess is not None and instance in sess:
- return sess
+ return _state_session(attributes.instance_state(instance))
+
+def _state_session(state):
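+    # sessions are registered in the module-level _sessions
+    # WeakValueDictionary, keyed by Session.hash_key; a session that
+    # has been garbage collected simply yields None here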
+ if state.session_id:
+ try:
+ return _sessions[state.session_id]
+ except KeyError:
+ pass
return None
# Lazy initialization to avoid circular imports
unitofwork.object_session = object_session
+unitofwork._state_session = _state_session
from sqlalchemy.orm import mapper
mapper._expire_state = _expire_state
+mapper._state_session = _state_session
-"""Defines a rudimental 'horizontal sharding' system which allows a
-Session to distribute queries and persistence operations across multiple
-databases.
+# shard.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
-For a usage example, see the example ``examples/sharding/attribute_shard.py``.
+"""Horizontal sharding support.
+
+Defines a rudimental 'horizontal sharding' system which allows a Session to
+distribute queries and persistence operations across multiple databases.
+
+For a usage example, see the file ``examples/sharding/attribute_shard.py``
+included in the source distribution.
"""
+
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import util
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.query import Query
-from sqlalchemy import exceptions, util
__all__ = ['ShardedSession', 'ShardedQuery']
+
class ShardedSession(Session):
def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
- """construct a ShardedSession.
-
- shard_chooser
- a callable which, passed a Mapper, a mapped instance, and possibly a
- SQL clause, returns a shard ID. this id may be based off of the
- attributes present within the object, or on some round-robin scheme. If
- the scheme is based on a selection, it should set whatever state on the
- instance to mark it in the future as participating in that shard.
-
- id_chooser
- a callable, passed a query and a tuple of identity values,
- which should return a list of shard ids where the ID might
- reside. The databases will be queried in the order of this
- listing.
-
- query_chooser
- for a given Query, returns the list of shard_ids where the query
- should be issued. Results from all shards returned will be
- combined together into a single listing.
-
+ """Construct a ShardedSession.
+
+ shard_chooser
+          A callable which, passed a Mapper, a mapped instance, and possibly a
+          SQL clause, returns a shard ID. This id may be based on the
+          attributes present within the object, or on some round-robin
+          scheme. If the scheme is based on a selection, it should set
+          whatever state is needed on the instance to mark it as
+          participating in that shard for future operations.
+
+ id_chooser
+ A callable, passed a query and a tuple of identity values, which
+ should return a list of shard ids where the ID might reside. The
+ databases will be queried in the order of this listing.
+
+ query_chooser
+ For a given Query, returns the list of shard_ids where the query
+ should be issued. Results from all shards returned will be combined
+ together into a single listing.
+
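+        A minimal usage sketch (the chooser logic, shard names, and
+        engines below are illustrative only; ``shards`` is assumed to
+        map shard ids to Engines)::
+
+            def shard_chooser(mapper, instance, clause=None):
+                # is_east() is a hypothetical application predicate
+                if is_east(instance):
+                    return 'east'
+                return 'west'
+
+            def id_chooser(query, ident):
+                return ['east', 'west']
+
+            def query_chooser(query):
+                return ['east', 'west']
+
+            session = ShardedSession(shard_chooser, id_chooser,
+                                     query_chooser,
+                                     shards={'east': engine_east,
+                                             'west': engine_west})
+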
"""
super(ShardedSession, self).__init__(**kwargs)
self.shard_chooser = shard_chooser
def _execute_and_instances(self, context):
if self._shard_id is not None:
- result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(context.statement, **self._params)
+ result = self.session.connection(mapper=self._mapper_zero(), shard_id=self._shard_id).execute(context.statement, **self._params)
try:
- return iter(self.instances(result, querycontext=context))
+ return iter(self.instances(result, context))
finally:
result.close()
else:
partial = []
for shard_id in self.query_chooser(self):
- result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(context.statement, **self._params)
+ result = self.session.connection(mapper=self._mapper_zero(), shard_id=shard_id).execute(context.statement, **self._params)
try:
- partial = partial + list(self.instances(result, querycontext=context))
+ partial = partial + list(self.instances(result, context))
finally:
result.close()
# if some kind of in memory 'sorting' were done, this is where it would happen
if o is not None:
return o
else:
- raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
+ raise sa_exc.InvalidRequestError("No instance found for identity %s" % repr(ident))
"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
-from sqlalchemy import sql, util, exceptions, logging
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import sql, util, log
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import visitors, expression, operators
from sqlalchemy.orm import mapper, attributes
-from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, serialize_path, deserialize_path
+from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, \
+ MapperOption, PropertyOption, serialize_path, deserialize_path
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
def init(self):
super(ColumnLoader, self).init()
self.columns = self.parent_property.columns
- self._should_log_debug = logging.is_debug_enabled(self.logger)
+ self._should_log_debug = log.is_debug_enabled(self.logger)
self.is_composite = hasattr(self.parent_property, 'composite_class')
- def setup_query(self, context, parentclauses=None, **kwargs):
+ def setup_query(self, context, entity, path, adapter, column_collection=None, **kwargs):
for c in self.columns:
- if parentclauses is not None:
- context.secondary_columns.append(parentclauses.aliased_column(c))
- else:
- context.primary_columns.append(c)
+ if adapter:
+ c = adapter.columns[c]
+ column_collection.append(c)
def init_class_attribute(self):
self.is_class_level = True
- if self.is_composite:
- self._init_composite_attribute()
+ self.logger.info("%s register managed attribute" % self)
+ coltype = self.columns[0].type
+        sessionlib.register_attribute(
+            self.parent.class_, self.key, uselist=False, useobject=False,
+            copy_function=coltype.copy_value,
+            compare_function=coltype.compare_values,
+            mutable_scalars=coltype.is_mutable(),
+            comparator=self.parent_property.comparator,
+            parententity=self.parent)
+
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ key, col = self.key, self.columns[0]
+ if adapter:
+ col = adapter.columns[col]
+ if col in row:
+ def new_execute(state, row, **flags):
+ state.dict[key] = row[col]
+
+ if self._should_log_debug:
+ new_execute = self.debug_callable(new_execute, self.logger,
+ "%s returning active column fetcher" % self,
+ lambda state, row, **flags: "%s populating %s" % (self, mapperutil.state_attribute_str(state, key))
+ )
+ return (new_execute, None)
else:
- self._init_scalar_attribute()
+ def new_execute(state, row, isnew, **flags):
+ if isnew:
+ state.expire_attributes([key])
+ if self._should_log_debug:
+ self.logger.debug("%s deferring load" % self)
+ return (new_execute, None)
+
+ColumnLoader.logger = log.class_logger(ColumnLoader)
+
+class CompositeColumnLoader(ColumnLoader):
+ def init_class_attribute(self):
+ self.is_class_level = True
+ self.logger.info("%s register managed composite attribute" % self)
- def _init_composite_attribute(self):
- self.logger.info("register managed composite attribute %s on class %s" % (self.key, self.parent.class_.__name__))
def copy(obj):
- return self.parent_property.composite_class(
- *obj.__composite_values__())
+ return self.parent_property.composite_class(*obj.__composite_values__())
+
def compare(a, b):
for col, aprop, bprop in zip(self.columns,
a.__composite_values__(),
return False
else:
return True
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator)
-
- def _init_scalar_attribute(self):
- self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
- coltype = self.columns[0].type
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
-
- def create_row_processor(self, selectcontext, mapper, row):
- if self.is_composite:
- for c in self.columns:
- if c not in row:
- break
- else:
- def new_execute(instance, row, **flags):
- if self._should_log_debug:
- self.logger.debug("populating %s with %s/%s..." % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
- instance.__dict__[self.key] = self.parent_property.composite_class(*[row[c] for c in self.columns])
- if self._should_log_debug:
- self.logger.debug("Returning active composite column fetcher for %s %s" % (mapper, self.key))
- return (new_execute, None, None)
-
- elif self.columns[0] in row:
- def new_execute(instance, row, **flags):
+        sessionlib.register_attribute(
+            self.parent.class_, self.key, uselist=False, useobject=False,
+            copy_function=copy, compare_function=compare,
+            mutable_scalars=True,
+            comparator=self.parent_property.comparator,
+            parententity=self.parent)
+
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class
+ if adapter:
+ columns = [adapter.columns[c] for c in columns]
+ for c in columns:
+ if c not in row:
+ def new_execute(state, row, isnew, **flags):
+ if isnew:
+ state.expire_attributes([key])
if self._should_log_debug:
- self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
- instance.__dict__[self.key] = row[self.columns[0]]
- if self._should_log_debug:
- self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
- return (new_execute, None, None)
+ self.logger.debug("%s deferring load" % self)
+ return (new_execute, None)
else:
- def new_execute(instance, row, isnew, **flags):
- if isnew:
- instance._state.expire_attributes([self.key])
+ def new_execute(state, row, **flags):
+ state.dict[key] = composite_class(*[row[c] for c in columns])
+
if self._should_log_debug:
- self.logger.debug("Deferring load for %s %s" % (mapper, self.key))
- return (new_execute, None, None)
+ new_execute = self.debug_callable(new_execute, self.logger,
+ "%s returning active composite column fetcher" % self,
+ lambda state, row, **flags: "populating %s" % (mapperutil.state_attribute_str(state, key))
+ )
-ColumnLoader.logger = logging.class_logger(ColumnLoader)
+ return (new_execute, None)
+CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader)
+
class DeferredColumnLoader(LoaderStrategy):
"""Deferred column loader, a per-column or per-column-group lazy loader."""
- def create_row_processor(self, selectcontext, mapper, row):
- if self.columns[0] in row:
- return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row)
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ col = self.columns[0]
+ if adapter:
+ col = adapter.columns[col]
+ if col in row:
+ return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter)
+
elif not self.is_class_level or len(selectcontext.options):
- def new_execute(instance, row, **flags):
- if self._should_log_debug:
- self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
- instance._state.set_callable(self.key, self.setup_loader(instance))
- return (new_execute, None, None)
+ def new_execute(state, row, **flags):
+ state.set_callable(self.key, self.setup_loader(state))
else:
- def new_execute(instance, row, **flags):
- if self._should_log_debug:
- self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
- instance._state.reset(self.key)
- return (new_execute, None, None)
+ def new_execute(state, row, **flags):
+ state.reset(self.key)
+
+ if self._should_log_debug:
+ new_execute = self.debug_callable(new_execute, self.logger, None,
+ lambda state, row, **flags: "set deferred callable on %s" % mapperutil.state_attribute_str(state, self.key)
+ )
+ return (new_execute, None)
def init(self):
super(DeferredColumnLoader, self).init()
raise NotImplementedError("Deferred loading for composite types not implemented yet")
self.columns = self.parent_property.columns
self.group = self.parent_property.group
- self._should_log_debug = logging.is_debug_enabled(self.logger)
+ self._should_log_debug = log.is_debug_enabled(self.logger)
def init_class_attribute(self):
self.is_class_level = True
- self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
+ self.logger.info("%s register managed attribute" % self)
+        coltype = self.columns[0].type
+        sessionlib.register_attribute(
+            self.parent.class_, self.key, uselist=False, useobject=False,
+            callable_=self.class_level_loader,
+            copy_function=coltype.copy_value,
+            compare_function=coltype.compare_values,
+            mutable_scalars=coltype.is_mutable(),
+            comparator=self.parent_property.comparator,
+            parententity=self.parent)
- def setup_query(self, context, only_load_props=None, **kwargs):
+ def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
if \
(self.group is not None and context.attributes.get(('undefer', self.group), False)) or \
(only_load_props and self.key in only_load_props):
- self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
+ self.parent_property._get_strategy(ColumnLoader).setup_query(context, entity, path, adapter, **kwargs)
- def class_level_loader(self, instance, props=None):
- if not mapper.has_mapper(instance):
+ def class_level_loader(self, state, props=None):
+ if not mapperutil._state_has_mapper(state):
return None
- localparent = mapper.object_mapper(instance)
+ localparent = mapper._state_mapper(state)
# adjust for the ColumnProperty associated with the instance
# not being our own ColumnProperty. This can occur when entity_name
# to the class.
prop = localparent.get_property(self.key)
if prop is not self.parent_property:
- return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
+ return prop._get_strategy(DeferredColumnLoader).setup_loader(state)
- return LoadDeferredColumns(instance, self.key, props)
+ return LoadDeferredColumns(state, self.key, props)
- def setup_loader(self, instance, props=None, create_statement=None):
- return LoadDeferredColumns(instance, self.key, props, optimizing_statement=create_statement)
+ def setup_loader(self, state, props=None, create_statement=None):
+ return LoadDeferredColumns(state, self.key, props)
-DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
+DeferredColumnLoader.logger = log.class_logger(DeferredColumnLoader)
class LoadDeferredColumns(object):
- """callable, serializable loader object used by DeferredColumnLoader"""
+ """serializable loader object used by DeferredColumnLoader"""
- def __init__(self, instance, key, keys, optimizing_statement=None):
- self.instance = instance
+ def __init__(self, state, key, keys):
+ self.state = state
self.key = key
self.keys = keys
- self.optimizing_statement = optimizing_statement
def __getstate__(self):
- return {'instance':self.instance, 'key':self.key, 'keys':self.keys}
+ return {'state':self.state, 'key':self.key, 'keys':self.keys}
def __setstate__(self, state):
- self.instance = state['instance']
+ self.state = state['state']
self.key = state['key']
self.keys = state['keys']
- self.optimizing_statement = None
def __call__(self):
- if not mapper.has_identity(self.instance):
+ state = self.state
+
+ if not mapper._state_has_identity(state):
return None
-
- localparent = mapper.object_mapper(self.instance, raiseerror=False)
+
+ localparent = mapper._state_mapper(state)
prop = localparent.get_property(self.key)
strategy = prop._get_strategy(DeferredColumnLoader)
toload = [self.key]
# narrow the keys down to just those which have no history
- group = [k for k in toload if k in self.instance._state.unmodified]
+ group = [k for k in toload if k in state.unmodified]
if strategy._should_log_debug:
- strategy.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(self.instance, self.key), group and ','.join(group) or 'None'))
+ strategy.logger.debug("deferred load %s group %s" % (mapperutil.state_attribute_str(state, self.key), group and ','.join(group) or 'None'))
- session = sessionlib.object_session(self.instance)
+ session = sessionlib._state_session(state)
if session is None:
- raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key))
+ raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (mapperutil.state_str(state), self.key))
query = session.query(localparent)
- if not self.optimizing_statement:
- ident = self.instance._instance_key[1]
- query._get(None, ident=ident, only_load_props=group, refresh_instance=self.instance._state)
- else:
- statement, params = self.optimizing_statement(self.instance)
- query.from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=self.instance._state)
+ ident = state.key[1]
+ query._get(None, ident=ident, only_load_props=group, refresh_instance=state)
return attributes.ATTR_WAS_SET
class DeferredOption(StrategizedOption):
class AbstractRelationLoader(LoaderStrategy):
def init(self):
super(AbstractRelationLoader, self).init()
- for attr in ['primaryjoin', 'secondaryjoin', 'secondary', 'foreign_keys', 'mapper', 'target', 'table', 'uselist', 'cascade', 'attributeext', 'order_by', 'remote_side', 'direction']:
+ for attr in ['mapper', 'target', 'table', 'uselist']:
setattr(self, attr, getattr(self.parent_property, attr))
- self._should_log_debug = logging.is_debug_enabled(self.logger)
+ self._should_log_debug = log.is_debug_enabled(self.logger)
- def _init_instance_attribute(self, instance, callable_=None):
+ def _init_instance_attribute(self, state, callable_=None):
if callable_:
- instance._state.set_callable(self.key, callable_)
+ state.set_callable(self.key, callable_)
else:
- instance._state.initialize(self.key)
+ state.initialize(self.key)
def _register_attribute(self, class_, callable_=None, **kwargs):
- self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
- sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, **kwargs)
+ self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar")))
+
+ if self.parent_property.backref:
+ attribute_ext = self.parent_property.backref.extension
+ else:
+ attribute_ext = None
+
+        sessionlib.register_attribute(
+            class_, self.key, uselist=self.uselist, useobject=True,
+            extension=attribute_ext,
+            cascade=self.parent_property.cascade, trackparent=True,
+            typecallable=self.parent_property.collection_class,
+            callable_=callable_,
+            comparator=self.parent_property.comparator,
+            parententity=self.parent, **kwargs)
class NoLoader(AbstractRelationLoader):
def init_class_attribute(self):
self.is_class_level = True
self._register_attribute(self.parent.class_)
- def create_row_processor(self, selectcontext, mapper, row):
- def new_execute(instance, row, ispostselect, **flags):
- if not ispostselect:
- if self._should_log_debug:
- self.logger.debug("initializing blank scalar/collection on %s" % mapperutil.attribute_str(instance, self.key))
- self._init_instance_attribute(instance)
- return (new_execute, None, None)
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ def new_execute(state, row, **flags):
+ self._init_instance_attribute(state)
+
+ if self._should_log_debug:
+ new_execute = self.debug_callable(new_execute, self.logger, None,
+ lambda state, row, **flags: "initializing blank scalar/collection on %s" % mapperutil.state_attribute_str(state, self.key)
+ )
+ return (new_execute, None)
-NoLoader.logger = logging.class_logger(NoLoader)
+NoLoader.logger = log.class_logger(NoLoader)
class LazyLoader(AbstractRelationLoader):
def init(self):
super(LazyLoader, self).init()
(self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property)
- self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.__lazywhere))
+ self.logger.info("%s lazy loading clause %s" % (self, self.__lazywhere))
# determine if our "lazywhere" clause is the same as the mapper's
# get() clause. then we can just use mapper.get()
#from sqlalchemy.orm import query
self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere)
if self.use_get:
- self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
+ self.logger.info("%s will use query.get() to optimize instance loads" % self)
def init_class_attribute(self):
self.is_class_level = True
self._register_attribute(self.parent.class_, callable_=self.class_level_loader)
- def lazy_clause(self, instance, reverse_direction=False):
- if instance is None:
+ def lazy_clause(self, state, reverse_direction=False):
+ if state is None:
return self._lazy_none_clause(reverse_direction)
if not reverse_direction:
# use the "committed" (database) version to get query column values
# also its a deferred value; so that when used by Query, the committed value is used
# after an autoflush occurs
- bindparam.value = lambda: mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key])
- return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam)
+ bindparam.value = lambda: mapper._get_committed_state_attr_by_column(state, bind_to_col[bindparam.key])
+ return visitors.cloned_traverse(criterion, {}, {'bindparam':visit_bindparam})
def _lazy_none_clause(self, reverse_direction=False):
if not reverse_direction:
binary.right = expression.null()
binary.operator = operators.is_
- return visitors.traverse(criterion, clone=True, visit_binary=visit_binary)
+ return visitors.cloned_traverse(criterion, {}, {'binary':visit_binary})
- def class_level_loader(self, instance, options=None, path=None):
- if not mapper.has_mapper(instance):
+ def class_level_loader(self, state, options=None, path=None):
+ if not mapperutil._state_has_mapper(state):
return None
- localparent = mapper.object_mapper(instance)
+ localparent = mapper._state_mapper(state)
# adjust for the PropertyLoader associated with the instance
# not being our own PropertyLoader. This can occur when entity_name
# to the class.
prop = localparent.get_property(self.key)
if prop is not self.parent_property:
- return prop._get_strategy(LazyLoader).setup_loader(instance)
+ return prop._get_strategy(LazyLoader).setup_loader(state)
- return LoadLazyAttribute(instance, self.key, options, path)
+ return LoadLazyAttribute(state, self.key, options, path)
- def setup_loader(self, instance, options=None, path=None):
- return LoadLazyAttribute(instance, self.key, options, path)
+ def setup_loader(self, state, options=None, path=None):
+ return LoadLazyAttribute(state, self.key, options, path)
- def create_row_processor(self, selectcontext, mapper, row):
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
if not self.is_class_level or len(selectcontext.options):
- def new_execute(instance, row, ispostselect, **flags):
- if not ispostselect:
- if self._should_log_debug:
- self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
- # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
- # which will override the class-level behavior
-
- self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options, selectcontext.query._current_path + selectcontext.path))
- return (new_execute, None, None)
+ path = path + (self.key,)
+ def new_execute(state, row, **flags):
+ # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
+ # which will override the class-level behavior
+ self._init_instance_attribute(state, callable_=self.setup_loader(state, selectcontext.options, selectcontext.query._current_path + path))
+
+ if self._should_log_debug:
+ new_execute = self.debug_callable(new_execute, self.logger, None,
+ lambda state, row, **flags: "set instance-level lazy loader on %s" % mapperutil.state_attribute_str(state, self.key)
+ )
+
+ return (new_execute, None)
else:
- def new_execute(instance, row, ispostselect, **flags):
- if not ispostselect:
- if self._should_log_debug:
- self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
- # we are the primary manager for this attribute on this class - reset its per-instance attribute state,
- # so that the class-level lazy loader is executed when next referenced on this instance.
- # this usually is not needed unless the constructor of the object referenced the attribute before we got
- # to load data into it.
- instance._state.reset(self.key)
- return (new_execute, None, None)
+ def new_execute(state, row, **flags):
+ # we are the primary manager for this attribute on this class - reset its per-instance attribute state,
+ # so that the class-level lazy loader is executed when next referenced on this instance.
+ # this usually is not needed unless the constructor of the object referenced the attribute before we got
+ # to load data into it.
+ state.reset(self.key)
+
+ if self._should_log_debug:
+ new_execute = self.debug_callable(new_execute, self.logger, None,
+ lambda state, row, **flags: "set class-level lazy loader on %s" % mapperutil.state_attribute_str(state, self.key)
+ )
+
+ return (new_execute, None)
def __create_lazy_clause(cls, prop, reverse_direction=False):
binds = {}
binds[col] = sql.bindparam(None, None, type_=col.type)
return binds[col]
return None
-
- lazywhere = prop.primaryjoin
+ lazywhere = prop.primaryjoin
+
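+        # replace columns on one side of the join condition with
+        # bindparams; at load time lazy_clause() populates each bindparam
+        # from the committed state of the parent instance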
if not prop.secondaryjoin or not reverse_direction:
- lazywhere = visitors.traverse(lazywhere, before_clone=col_to_bind, clone=True)
+ lazywhere = visitors.replacement_traverse(lazywhere, {}, col_to_bind)
if prop.secondaryjoin is not None:
secondaryjoin = prop.secondaryjoin
if reverse_direction:
- secondaryjoin = visitors.traverse(secondaryjoin, before_clone=col_to_bind, clone=True)
+ secondaryjoin = visitors.replacement_traverse(secondaryjoin, {}, col_to_bind)
lazywhere = sql.and_(lazywhere, secondaryjoin)
bind_to_col = dict([(binds[col].key, col) for col in binds])
return (lazywhere, bind_to_col, equated_columns)
__create_lazy_clause = classmethod(__create_lazy_clause)
-LazyLoader.logger = logging.class_logger(LazyLoader)
+LazyLoader.logger = log.class_logger(LazyLoader)
class LoadLazyAttribute(object):
- """callable, serializable loader object used by LazyLoader"""
+ """serializable loader object used by LazyLoader"""
- def __init__(self, instance, key, options, path):
- self.instance = instance
+ def __init__(self, state, key, options, path):
+ self.state = state
self.key = key
self.options = options
self.path = path
def __getstate__(self):
- return {'instance':self.instance, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
+ return {'state':self.state, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
def __setstate__(self, state):
- self.instance = state['instance']
+ self.state = state['state']
self.key = state['key']
- self.options= state['options']
+ self.options = state['options']
self.path = deserialize_path(state['path'])
def __call__(self):
- instance = self.instance
-
- if not mapper.has_identity(instance):
+ state = self.state
+ if not mapper._state_has_identity(state):
return None
- instance_mapper = mapper.object_mapper(instance)
+ instance_mapper = mapper._state_mapper(state)
prop = instance_mapper.get_property(self.key)
strategy = prop._get_strategy(LazyLoader)
if strategy._should_log_debug:
- strategy.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
+ strategy.logger.debug("loading %s" % mapperutil.state_attribute_str(state, self.key))
- session = sessionlib.object_session(instance)
+ session = sessionlib._state_session(state)
if session is None:
- try:
- session = instance_mapper.get_session()
- except exceptions.InvalidRequestError:
- raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+ raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.state_str(state), self.key))
- q = session.query(prop.mapper).autoflush(False)
+ q = session.query(prop.mapper).autoflush(False)._adapt_all_clauses()
+
if self.path:
q = q._with_current_path(self.path)
ident = []
allnulls = True
for primary_key in prop.mapper.primary_key:
- val = instance_mapper._get_committed_attr_by_column(instance, strategy._equated_columns[primary_key])
+ val = instance_mapper._get_committed_state_attr_by_column(state, strategy._equated_columns[primary_key])
allnulls = allnulls and val is None
ident.append(val)
if allnulls:
q = q._conditional_options(*self.options)
return q.get(ident)
- if strategy.order_by is not False:
- q = q.order_by(strategy.order_by)
- elif strategy.secondary is not None and strategy.secondary.default_order_by() is not None:
- q = q.order_by(strategy.secondary.default_order_by())
+ if prop.order_by is not False:
+ q = q.order_by(prop.order_by)
+ elif prop.secondary is not None and prop.secondary.default_order_by() is not None:
+ q = q.order_by(prop.secondary.default_order_by())
if self.options:
q = q._conditional_options(*self.options)
- q = q.filter(strategy.lazy_clause(instance))
+ q = q.filter(strategy.lazy_clause(state))
result = q.all()
if strategy.uselist:
self.join_depth = self.parent_property.join_depth
def init_class_attribute(self):
- # class-level eager strategy; add the PropertyLoader
- # to the parent's list of "eager loaders"; this tells the Query
- # that eager loaders will be used in a normal query
- self.parent._eager_loaders.add(self.parent_property)
-
- # initialize a lazy loader on the class level attribute
self.parent_property._get_strategy(LazyLoader).init_class_attribute()
- def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs):
+ def setup_query(self, context, entity, path, adapter, column_collection=None, parentmapper=None, **kwargs):
"""Add a left outer join to the statement thats being constructed."""
+
+ path = path + (self.key,)
+
+ # check for user-defined eager alias
+ if ("eager_row_processor", path) in context.attributes:
+ clauses = context.attributes[("eager_row_processor", path)]
+
+ adapter = entity._get_entity_clauses(context.query, context)
+ if adapter and clauses:
+ context.attributes[("eager_row_processor", path)] = clauses = adapter.wrap(clauses)
+ elif adapter:
+ context.attributes[("eager_row_processor", path)] = clauses = adapter
+
+ else:
- path = context.path
-
+ clauses = self.__create_eager_join(context, entity, path, adapter, parentmapper)
+ if not clauses:
+ return
+
+ context.attributes[("eager_row_processor", path)] = clauses
+
+ for value in self.mapper._iterate_polymorphic_properties():
+ value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=context.secondary_columns)
+
+ def __create_eager_join(self, context, entity, path, adapter, parentmapper):
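+        # builds a LEFT OUTER JOIN from the parent selectable to an
+        # anonymous alias of the target; the ORMAdapter produced here is
+        # cached per path so repeated query compilations can reuse it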
# check for join_depth or basic recursion,
# if the current path was not explicitly stated as
# a desired "loaderstrategy" (i.e. via query.options())
if self.mapper.base_mapper in path:
return
- if ("eager_row_processor", path) in context.attributes:
- # if user defined eager_row_processor, that's contains_eager().
- # don't render LEFT OUTER JOIN, generate an AliasedClauses from
- # the decorator (this is a hack here, cleaned up in 0.5)
- cl = context.attributes[("eager_row_processor", path)]
- if cl:
- row = cl(None)
- class ActsLikeAliasedClauses(object):
- def aliased_column(self, col):
- return row.map[col]
- clauses = ActsLikeAliasedClauses()
- else:
- clauses = None
- else:
- clauses = self.__create_eager_join(context, path, parentclauses, parentmapper, **kwargs)
- if not clauses:
- return
-
- for value in self.mapper._iterate_polymorphic_properties():
- context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.mapper)
-
- def __create_eager_join(self, context, path, parentclauses, parentmapper, **kwargs):
if parentmapper is None:
- localparent = context.mapper
+ localparent = entity.mapper
else:
localparent = parentmapper
-
- if context.eager_joins:
- towrap = context.eager_joins
+
+ # whether or not the Query will wrap the selectable in a subquery,
+ # and then attach eager load joins to that (i.e., in the case of LIMIT/OFFSET etc.)
+ should_nest_selectable = context.query._should_nest_selectable
+
+ if entity in context.eager_joins:
+ entity_key, default_towrap = entity, entity.selectable
+ elif should_nest_selectable or not context.from_clause or not sql_util.search(context.from_clause, entity.selectable):
+ # if no from_clause, or a from_clause we can't join to, or a subquery is going to be generated,
+ # store eager joins per _MappedEntity; Query._compile_context will
+ # add them as separate selectables to the select(), or splice them together
+ # after the subquery is generated
+ entity_key, default_towrap = entity, entity.selectable
else:
- towrap = context.from_clause
-
- # create AliasedClauses object to build up the eager query. this is cached after 1st creation.
+ # otherwise, create a single eager join from the from clause.
+ # Query._compile_context will adapt as needed and append to the
+ # FROM clause of the select().
+ entity_key, default_towrap = None, context.from_clause
+
+ towrap = context.eager_joins.setdefault(entity_key, default_towrap)
+
+ # create AliasedClauses object to build up the eager query. this is cached after 1st creation.
+ # this also allows ORMJoin to cache the aliased joins it produces since we pass the same
+ # args each time in the typical case.
+ path_key = util.WeakCompositeKey(*path)
try:
- clauses = self.clauses[path]
+ clauses = self.clauses[path_key]
except KeyError:
- clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.primaryjoin, self.parent_property.secondaryjoin, parentclauses)
- self.clauses[path] = clauses
+ self.clauses[path_key] = clauses = mapperutil.ORMAdapter(mapperutil.AliasedClass(self.mapper),
+ equivalents=self.mapper._equivalent_columns,
+ chain_to=adapter)
- # place the "row_decorator" from the AliasedClauses into the QueryContext, where it will
- # be picked up in create_row_processor() when results are fetched
- context.attributes[("eager_row_processor", path)] = clauses.row_decorator
-
- if self.secondaryjoin is not None:
- context.eager_joins = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin)
-
- # TODO: check for "deferred" cols on parent/child tables here ? this would only be
- # useful if the primary/secondaryjoin are against non-PK columns on the tables (and therefore might be deferred)
-
- if self.order_by is False and self.secondary.default_order_by() is not None:
- context.eager_order_by += clauses.secondary.default_order_by()
+ if adapter:
+ if getattr(adapter, 'aliased_class', None):
+ onclause = getattr(adapter.aliased_class, self.key, self.parent_property)
+ else:
+ onclause = getattr(mapperutil.AliasedClass(self.parent, adapter.selectable), self.key, self.parent_property)
else:
- context.eager_joins = towrap.outerjoin(clauses.alias, clauses.primaryjoin)
- # ensure all the cols on the parent side are actually in the
+ onclause = self.parent_property
+
+ context.eager_joins[entity_key] = eagerjoin = mapperutil.outerjoin(towrap, clauses.aliased_class, onclause)
+
+ # send a hint to the Query as to where it may "splice" this join
+ eagerjoin.stop_on = entity.selectable
+
+ if not self.parent_property.secondary and context.query._should_nest_selectable and not parentmapper:
+ # for parentclause that is the non-eager end of the join,
+ # ensure all the parent cols in the primaryjoin are actually in the
# columns clause (i.e. are not deferred), so that aliasing applied by the Query propagates
# those columns outward. This has the effect of "undefering" those columns.
- for col in sql_util.find_columns(clauses.primaryjoin):
+ for col in sql_util.find_columns(self.parent_property.primaryjoin):
if localparent.mapped_table.c.contains_column(col):
+ if adapter:
+ col = adapter.columns[col]
context.primary_columns.append(col)
-
- if self.order_by is False and clauses.alias.default_order_by() is not None:
- context.eager_order_by += clauses.alias.default_order_by()
-
- if clauses.order_by:
- context.eager_order_by += util.to_list(clauses.order_by)
+ if self.parent_property.order_by is False:
+ if self.parent_property.secondaryjoin:
+ default_order_by = eagerjoin.left.right.default_order_by()
+ else:
+ default_order_by = eagerjoin.right.default_order_by()
+ if default_order_by:
+ context.eager_order_by += default_order_by
+ elif self.parent_property.order_by:
+ context.eager_order_by += eagerjoin._target_adapter.copy_and_process(util.to_list(self.parent_property.order_by))
+
return clauses
- def _create_row_decorator(self, selectcontext, row, path):
- """Create a *row decorating* function that will apply eager
- aliasing to the row.
-
- Also check that an identity key can be retrieved from the row,
- else return None.
- """
-
- #print "creating row decorator for path ", "->".join([str(s) for s in path])
-
- if ("eager_row_processor", path) in selectcontext.attributes:
- decorator = selectcontext.attributes[("eager_row_processor", path)]
- if decorator is None:
- decorator = lambda row: row
+ def __create_eager_adapter(self, context, row, adapter, path):
+ if ("eager_row_processor", path) in context.attributes:
+ decorator = context.attributes[("eager_row_processor", path)]
else:
if self._should_log_debug:
self.logger.debug("Could not locate aliased clauses for key: " + str(path))
- return None
+ return False
+ if adapter and decorator:
+ decorator = adapter.wrap(decorator)
+ elif adapter:
+ decorator = adapter
+
try:
- decorated_row = decorator(row)
- # check for identity key
- identity_key = self.mapper.identity_key_from_row(decorated_row)
- # and its good
+ identity_key = self.mapper.identity_key_from_row(row, decorator)
return decorator
except KeyError, k:
# no identity key - dont return a row processor, will cause a degrade to lazy
if self._should_log_debug:
- self.logger.debug("could not locate identity key from row '%s'; missing column '%s'" % (repr(decorated_row), str(k)))
- return None
-
- def create_row_processor(self, selectcontext, mapper, row):
+ self.logger.debug("could not locate identity key from row; missing column '%s'" % k)
+ return False
- row_decorator = self._create_row_decorator(selectcontext, row, selectcontext.path)
- pathstr = ','.join([str(x) for x in selectcontext.path])
- if row_decorator is not None:
- def execute(instance, row, isnew, **flags):
- decorated_row = row_decorator(row)
-
- if not self.uselist:
- if self._should_log_debug:
- self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
+ def create_row_processor(self, context, path, mapper, row, adapter):
+ path = path + (self.key,)
+ eager_adapter = self.__create_eager_adapter(context, row, adapter, path)
+
+ if eager_adapter is not False:
+ key = self.key
+ _instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter)
+
+ if not self.uselist:
+ def execute(state, row, isnew, **flags):
if isnew:
# set a scalar object instance directly on the
# parent object, bypassing InstrumentedAttribute
# event handlers.
- #
- instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None)
+ state.dict[key] = _instance(row, None)
else:
# call _instance on the row, even though the object has been created,
# so that we further descend into properties
- self.mapper._instance(selectcontext, decorated_row, None)
- else:
- if isnew or self.key not in instance._state.appenders:
- # appender_key can be absent from selectcontext.attributes with isnew=False
+ _instance(row, None)
+ else:
+ def execute(state, row, isnew, **flags):
+ if isnew or (state, key) not in context.attributes:
+                    # the appender for (state, key) can be absent from
+                    # context.attributes with isnew=False when self-referential
+                    # eager loading is used; the same instance may be present
+                    # in two distinct sets of result columns
-
- if self._should_log_debug:
- self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key))
- collection = attributes.init_collection(instance, self.key)
+ collection = attributes.init_collection(state, key)
appender = util.UniqueAppender(collection, 'append_without_event')
- instance._state.appenders[self.key] = appender
-
- result_list = instance._state.appenders[self.key]
- if self._should_log_debug:
- self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
+ context.attributes[(state, key)] = appender
+
+                result_list = context.attributes[(state, key)]
- self.mapper._instance(selectcontext, decorated_row, result_list)
+ _instance(row, result_list)
if self._should_log_debug:
- self.logger.debug("Returning eager instance loader for %s" % str(self))
+ execute = self.debug_callable(execute, self.logger,
+ "%s returning eager instance loader" % self,
+                lambda state, row, isnew, **flags: "%s eagerload %s" % (self, self.uselist and "collection" or "scalar attribute")
+ )
- return (execute, execute, None)
+ return (execute, execute)
else:
if self._should_log_debug:
- self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
- return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
+ self.logger.debug("%s degrading to lazy loader" % self)
+ return self.parent_property._get_strategy(LazyLoader).create_row_processor(context, path, mapper, row, adapter)
- def __str__(self):
- return str(self.parent) + "." + self.key
-
-EagerLoader.logger = logging.class_logger(EagerLoader)
+EagerLoader.logger = log.class_logger(EagerLoader)
class EagerLazyOption(StrategizedOption):
def __init__(self, key, lazy=True, chained=False, mapper=None):
def is_chained(self):
return not self.lazy and self.chained
- def process_query_property(self, query, paths):
- if self.lazy:
- if paths[-1] in query._eager_loaders:
- query._eager_loaders = query._eager_loaders.difference(util.Set([paths[-1]]))
- else:
- if not self.chained:
- paths = [paths[-1]]
- res = util.Set()
- for path in paths:
- if len(path) - len(query._current_path) == 2:
- res.add(path)
- query._eager_loaders = query._eager_loaders.union(res)
- super(EagerLazyOption, self).process_query_property(query, paths)
-
def get_strategy_class(self):
if self.lazy:
return LazyLoader
elif self.lazy is None:
return NoLoader
-EagerLazyOption.logger = logging.class_logger(EagerLazyOption)
-
-class RowDecorateOption(PropertyOption):
- def __init__(self, key, decorator=None, alias=None):
- super(RowDecorateOption, self).__init__(key)
- self.decorator = decorator
+class LoadEagerFromAliasOption(PropertyOption):
+ def __init__(self, key, alias=None):
+ super(LoadEagerFromAliasOption, self).__init__(key)
+ if alias:
+ if not isinstance(alias, basestring):
+ m, alias, is_aliased_class = mapperutil._entity_info(alias)
self.alias = alias
def process_query_property(self, query, paths):
- if self.alias is not None and self.decorator is None:
- (mapper, propname) = paths[-1][-2:]
-
- prop = mapper.get_property(propname, resolve_synonyms=True)
+ if self.alias:
if isinstance(self.alias, basestring):
- self.alias = prop.target.alias(self.alias)
+ (mapper, propname) = paths[-1][-2:]
- self.decorator = mapperutil.create_row_adapter(self.alias)
- query._attributes[("eager_row_processor", paths[-1])] = self.decorator
+                prop = mapper.get_property(propname, resolve_synonyms=True)
+                self.alias = prop.target.alias(self.alias)
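+            # adapt result-row column lookups to the aliased selectable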
+ query._attributes[("eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias)
+ else:
+ query._attributes[("eager_row_processor", paths[-1])] = None
-RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
based on join conditions.
"""
-from sqlalchemy import schema, exceptions, util
-from sqlalchemy.sql import visitors, operators, util as sqlutil
-from sqlalchemy import logging
-from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY # legacy
+from sqlalchemy.orm import exc, util as mapperutil
def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs):
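+    # copy the value of each "left" column from the source state onto
+    # the corresponding "right" column of the destination state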
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(source, l)
- except exceptions.UnmappedColumnError:
+ except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
try:
dest_mapper._set_state_attr_by_column(dest, r, value)
- except exceptions.UnmappedColumnError:
- self._raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
+ except exc.UnmappedColumnError:
+ _raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
def clear(dest, dest_mapper, synchronize_pairs):
for l, r in synchronize_pairs:
if r.primary_key:
- raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest)))
+ raise AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest)))
try:
dest_mapper._set_state_attr_by_column(dest, r, None)
- except exceptions.UnmappedColumnError:
+ except exc.UnmappedColumnError:
_raise_col_to_prop(True, None, l, dest_mapper, r)
def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
try:
oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
value = source_mapper._get_state_attr_by_column(source, l)
- except exceptions.UnmappedColumnError:
- self._raise_col_to_prop(False, source_mapper, l, None, r)
+ except exc.UnmappedColumnError:
+ _raise_col_to_prop(False, source_mapper, l, None, r)
dest[r.key] = value
dest[old_prefix + r.key] = oldvalue
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(source, l)
- except exceptions.UnmappedColumnError:
+ except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
-
+
dict_[r.key] = value
def source_changes(uowcommit, source, source_mapper, synchronize_pairs):
for l, r in synchronize_pairs:
try:
prop = source_mapper._get_col_to_prop(l)
- except exceptions.UnmappedColumnError:
+ except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
(added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True)
if added and deleted:
for l, r in synchronize_pairs:
try:
prop = dest_mapper._get_col_to_prop(r)
- except exceptions.UnmappedColumnError:
+ except exc.UnmappedColumnError:
_raise_col_to_prop(True, None, l, dest_mapper, r)
(added, unchanged, deleted) = uowcommit.get_attribute_history(dest, prop.key, passive=True)
if added and deleted:
def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
if isdest:
- raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper))
+ raise exc.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper))
else:
- raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column))
-
+ raise exc.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column))
pattern. The Unit of Work then maintains lists of objects that are
new, dirty, or deleted and provides the capability to flush all those
changes at once.
+
"""
-import StringIO, weakref
-from sqlalchemy import util, logging, topological, exceptions
+import StringIO
+
+from sqlalchemy import util, log, topological
from sqlalchemy.orm import attributes, interfaces
from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.mapper import object_mapper, _state_mapper, has_identity
+from sqlalchemy.orm.mapper import _state_mapper
# Load lazily
object_session = None
+_state_session = None
class UOWEventHandler(interfaces.AttributeExtension):
"""An event handler added to all relation attributes which handles
self.class_ = class_
self.cascade = cascade
- def _target_mapper(self, obj):
- prop = object_mapper(obj).get_property(self.key)
+ def _target_mapper(self, state):
+ prop = _state_mapper(state).get_property(self.key)
return prop.mapper
- def append(self, obj, item, initiator):
+ def append(self, state, item, initiator):
# process "save_update" cascade rules for when an instance is appended to the list of another instance
- sess = object_session(obj)
+ sess = _state_session(state)
if sess:
if self.cascade.save_update and item not in sess:
- sess.save_or_update(item, entity_name=self._target_mapper(obj).entity_name)
+ sess.save_or_update(item, entity_name=self._target_mapper(state).entity_name)
- def remove(self, obj, item, initiator):
- sess = object_session(obj)
+ def remove(self, state, item, initiator):
+ sess = _state_session(state)
if sess:
# expunge pending orphans
if self.cascade.delete_orphan and item in sess.new:
- if self._target_mapper(obj)._is_orphan(item):
+ if self._target_mapper(state)._is_orphan(attributes.instance_state(item)):
sess.expunge(item)
- def set(self, obj, newvalue, oldvalue, initiator):
+ def set(self, state, newvalue, oldvalue, initiator):
# process "save_update" cascade rules for when an instance is attached to another instance
if oldvalue is newvalue:
return
- sess = object_session(obj)
+ sess = _state_session(state)
if sess:
if newvalue is not None and self.cascade.save_update and newvalue not in sess:
- sess.save_or_update(newvalue, entity_name=self._target_mapper(obj).entity_name)
+ sess.save_or_update(newvalue, entity_name=self._target_mapper(state).entity_name)
if self.cascade.delete_orphan and oldvalue in sess.new:
sess.expunge(oldvalue)
-class UnitOfWork(object):
- """Main UOW object which stores lists of dirty/new/deleted objects.
-
- Provides top-level *flush* functionality as well as the
- default transaction boundaries involved in a write
- operation.
- """
-
- def __init__(self, session):
- if session.weak_identity_map:
- self.identity_map = attributes.WeakInstanceDict()
- else:
- self.identity_map = attributes.StrongInstanceDict()
-
- self.new = {} # InstanceState->object, strong refs object
- self.deleted = {} # same
- self.logger = logging.instance_logger(self, echoflag=session.echo_uow)
-
- def _remove_deleted(self, state):
- if '_instance_key' in state.dict:
- del self.identity_map[state.dict['_instance_key']]
- self.deleted.pop(state, None)
- self.new.pop(state, None)
-
- def _is_valid(self, state):
- if '_instance_key' in state.dict:
- return state.dict['_instance_key'] in self.identity_map
- else:
- return state in self.new
-
- def _register_clean(self, state):
- """register the given object as 'clean' (i.e. persistent) within this unit of work, after
- a save operation has taken place."""
-
- mapper = _state_mapper(state)
- instance_key = mapper._identity_key_from_state(state)
-
- if '_instance_key' not in state.dict:
- state.dict['_instance_key'] = instance_key
-
- elif state.dict['_instance_key'] != instance_key:
- # primary key switch
- del self.identity_map[state.dict['_instance_key']]
- state.dict['_instance_key'] = instance_key
-
- if hasattr(state, 'insert_order'):
- delattr(state, 'insert_order')
-
- o = state.obj()
- # prevent against last minute dereferences of the object
- # TODO: identify a code path where state.obj() is None
- if o is not None:
- self.identity_map[state.dict['_instance_key']] = o
- state.commit_all()
-
- # remove from new last, might be the last strong ref
- self.new.pop(state, None)
-
- def register_new(self, obj):
- """register the given object as 'new' (i.e. unsaved) within this unit of work."""
-
- if hasattr(obj, '_instance_key'):
- raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj))
- if obj._state not in self.new:
- self.new[obj._state] = obj
- obj._state.insert_order = len(self.new)
-
- def register_deleted(self, obj):
- """register the given persistent object as 'to be deleted' within this unit of work."""
-
- self.deleted[obj._state] = obj
-
- def locate_dirty(self):
- """return a set of all persistent instances within this unit of work which
- either contain changes or are marked as deleted.
- """
-
- # a little bit of inlining for speed
- return util.IdentitySet([x for x in self.identity_map.values()
- if x._state not in self.deleted
- and (
- x._state.modified
- or (x.__class__._class_state.has_mutable_scalars and x._state.is_modified())
- )
- ])
-
- def flush(self, session, objects=None):
- """create a dependency tree of all pending SQL operations within this unit of work and execute."""
-
- dirty = [x for x in self.identity_map.all_states()
- if x.modified
- or (x.class_._class_state.has_mutable_scalars and x.is_modified())
- ]
-
- if not dirty and not self.deleted and not self.new:
- return
-
- deleted = util.Set(self.deleted)
- new = util.Set(self.new)
-
- dirty = util.Set(dirty).difference(deleted)
-
- flush_context = UOWTransaction(self, session)
-
- if session.extension is not None:
- session.extension.before_flush(session, flush_context, objects)
-
- # create the set of all objects we want to operate upon
- if objects:
- # specific list passed in
- objset = util.Set([o._state for o in objects])
- else:
- # or just everything
- objset = util.Set(self.identity_map.all_states()).union(new)
-
- # store objects whose fate has been decided
- processed = util.Set()
-
- # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted.
- for state in new.union(dirty).intersection(objset).difference(deleted):
- if state in processed:
- continue
-
- obj = state.obj()
- is_orphan = _state_mapper(state)._is_orphan(obj)
- if is_orphan and not has_identity(obj):
- raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" %
- (
- obj,
- ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans])
- ))
- flush_context.register_object(state, isdelete=is_orphan)
- processed.add(state)
-
- # put all remaining deletes into the flush context.
- for state in deleted.intersection(objset).difference(processed):
- flush_context.register_object(state, isdelete=True)
-
- if len(flush_context.tasks) == 0:
- return
-
- session.create_transaction(autoflush=False)
- flush_context.transaction = session.transaction
- try:
- flush_context.execute()
-
- if session.extension is not None:
- session.extension.after_flush(session, flush_context)
- session.commit()
- except:
- session.rollback()
- raise
-
- flush_context.post_exec()
-
- if session.extension is not None:
- session.extension.after_flush_postexec(session, flush_context)
-
- def prune_identity_map(self):
- """Removes unreferenced instances cached in a strong-referencing identity map.
-
- Note that this method is only meaningful if "weak_identity_map"
- on the parent Session is set to False and therefore this UnitOfWork's
- identity map is a regular dictionary
-
- Removes any object in the identity map that is not referenced
- in user code or scheduled for a unit of work operation. Returns
- the number of objects pruned.
- """
-
- if isinstance(self.identity_map, attributes.WeakInstanceDict):
- return 0
- ref_count = len(self.identity_map)
- dirty = self.locate_dirty()
- keepers = weakref.WeakValueDictionary(self.identity_map)
- self.identity_map.clear()
- self.identity_map.update(keepers)
- return ref_count - len(self.identity_map)
class UOWTransaction(object):
"""Handles the details of organizing and executing transaction
packages.
"""
- def __init__(self, uow, session):
- self.uow = uow
+ def __init__(self, session):
self.session = session
self.mapper_flush_opts = session._mapper_flush_opts
# information.
self.attributes = {}
- self.logger = logging.instance_logger(self, echoflag=session.echo_uow)
+ self.logger = log.instance_logger(self, echoflag=session.echo_uow)
def get_attribute_history(self, state, key, passive=True):
hashkey = ("history", state, key)
(added, unchanged, deleted) = attributes.get_history(state, key, passive=passive)
self.attributes[hashkey] = (added, unchanged, deleted, passive)
- if added is None:
+ if added is None or not state.get_impl(key).uses_objects:
return (added, unchanged, deleted)
else:
return (
- [getattr(c, '_state', c) for c in added],
- [getattr(c, '_state', c) for c in unchanged],
- [getattr(c, '_state', c) for c in deleted],
+ [c is not None and attributes.instance_state(c) or None for c in added],
+ [c is not None and attributes.instance_state(c) or None for c in unchanged],
+ [c is not None and attributes.instance_state(c) or None for c in deleted],
)
-
- def register_object(self, state, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs):
+ def register_object(self, state, isdelete=False, listonly=False, postupdate=False, post_update_cols=None):
# if object is not in the overall session, do nothing
- if not self.uow._is_valid(state):
+ if not self.session._contains_state(state):
if self._should_log_debug:
self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.state_str(state)))
return
self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.state_str(state), isdelete, listonly, postupdate))
mapper = _state_mapper(state)
-
+
task = self.get_task_by_mapper(mapper)
if postupdate:
task.append_postupdate(state, post_update_cols)
else:
- task.append(state, listonly, isdelete=isdelete, **kwargs)
+ task.append(state, listonly=listonly, isdelete=isdelete)
def set_row_switch(self, state):
"""mark a deleted object as a 'row switch'.
import uowdumper
uowdumper.UOWDumper(tasks, buf)
return buf.getvalue()
-
- def post_exec(self):
+
+ def elements(self):
+ """return an iterator of all UOWTaskElements within this UOWTransaction."""
+ for task in self.tasks.values():
+ for elem in task.elements:
+ yield elem
+ elements = property(elements)
+
+ def finalize_flush_changes(self):
"""mark processed objects as clean / deleted after a successful flush().
this method is called within the flush() method after the
execute() method has succeeded and the transaction has been committed.
"""
- for task in self.tasks.values():
- for elem in task.elements:
- if elem.state is None:
- continue
- if elem.isdelete:
- self.uow._remove_deleted(elem.state)
- else:
- self.uow._register_clean(elem.state)
+ for elem in self.elements:
+ if elem.isdelete:
+ self.session._remove_newly_deleted(elem.state)
+ else:
+ self.session._register_newly_persistent(elem.state)
def _sort_dependencies(self):
nodes = topological.sort_with_cycles(self.dependencies,
class UOWTask(object):
"""Represents all of the objects in the UOWTransaction which correspond to
- a particular mapper. This is the primary class of three classes used to generate
- the elements of the dependency graph.
+ a particular mapper.
+
"""
-
def __init__(self, uowtransaction, mapper, base_task=None):
self.uowtransaction = uowtransaction
# mapping of InstanceState -> UOWTaskElement
self._objects = {}
+ self.dependent_tasks = []
self.dependencies = util.Set()
self.cyclical_dependencies = util.Set()
rec.update(listonly, isdelete)
- def _append_cyclical_childtask(self, task):
- if "cyclical" not in self._objects:
- self._objects["cyclical"] = UOWTaskElement(None)
- self._objects["cyclical"].childtasks.append(task)
-
def append_postupdate(self, state, post_update_cols):
"""issue a 'post update' UPDATE statement via this object's mapper immediately.
"""
        # postupdates are UPDATED immediately (for now)
- # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns
- # instead of __eq__
+ # convert post_update_cols list to a Set so that __hash__() is used to compare columns
+ # instead of __eq__()
self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
def __contains__(self, state):
for rec in callable(task):
yield rec
return property(collection)
-
- elements = property(lambda self:self._objects.values())
- polymorphic_elements = _polymorphic_collection(lambda task:task.elements)
-
- polymorphic_tosave_elements = property(lambda self: [rec for rec in self.polymorphic_elements
- if not rec.isdelete])
-
- polymorphic_todelete_elements = property(lambda self:[rec for rec in self.polymorphic_elements
- if rec.isdelete])
+ def _elements(self):
+ return self._objects.values()
+ elements = property(_elements)
+
+ polymorphic_elements = _polymorphic_collection(_elements)
- polymorphic_tosave_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements
- if rec.state is not None and not rec.listonly and rec.isdelete is False])
+ def polymorphic_tosave_elements(self):
+ return [rec for rec in self.polymorphic_elements if not rec.isdelete]
+ polymorphic_tosave_elements = property(polymorphic_tosave_elements)
+
+ def polymorphic_todelete_elements(self):
+ return [rec for rec in self.polymorphic_elements if rec.isdelete]
+ polymorphic_todelete_elements = property(polymorphic_todelete_elements)
+
+ def polymorphic_tosave_objects(self):
+ return [
+ rec.state for rec in self.polymorphic_elements
+ if rec.state is not None and not rec.listonly and rec.isdelete is False
+ ]
+ polymorphic_tosave_objects = property(polymorphic_tosave_objects)
- polymorphic_todelete_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements
- if rec.state is not None and not rec.listonly and rec.isdelete is True])
+ def polymorphic_todelete_objects(self):
+ return [
+ rec.state for rec in self.polymorphic_elements
+ if rec.state is not None and not rec.listonly and rec.isdelete is True
+ ]
+ polymorphic_todelete_objects = property(polymorphic_todelete_objects)
- polymorphic_dependencies = _polymorphic_collection(lambda task:task.dependencies)
+ def polymorphic_dependencies(self):
+ return self.dependencies
+ polymorphic_dependencies = _polymorphic_collection(polymorphic_dependencies)
- polymorphic_cyclical_dependencies = _polymorphic_collection(lambda task:task.cyclical_dependencies)
+ def polymorphic_cyclical_dependencies(self):
+ return self.cyclical_dependencies
+ polymorphic_cyclical_dependencies = _polymorphic_collection(polymorphic_cyclical_dependencies)
def _sort_circular_dependencies(self, trans, cycles):
"""Create a hierarchical tree of *subtasks*
if t is None:
t = UOWTask(self.uowtransaction, originating_task.mapper)
nexttasks[originating_task] = t
- parenttask._append_cyclical_childtask(t)
+ parenttask.dependent_tasks.append(t)
t.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete)
if state in dependencies:
return ret
def __repr__(self):
- if self.mapper is not None:
- if self.mapper.__class__.__name__ == 'Mapper':
- name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.description
- else:
- name = repr(self.mapper)
- else:
- name = '(none)'
- return ("UOWTask(%s) Mapper: '%s'" % (hex(id(self)), name))
+ return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper))
class UOWTaskElement(object):
- """An element within a UOWTask.
-
- Corresponds to a single object instance to be saved, deleted, or
- just part of the transaction as a placeholder for further
- dependencies (i.e. 'listonly').
-
- may also store additional sub-UOWTasks.
+ """Corresponds to a single InstanceState to be saved, deleted,
+ or otherwise marked as having dependencies. A collection of
+    UOWTaskElements is held by a UOWTask.
+
"""
-
def __init__(self, state):
self.state = state
self.listonly = True
- self.childtasks = []
self.isdelete = False
self.__preprocessed = {}
class UOWDependencyProcessor(object):
"""In between the saving and deleting of objects, process
- *dependent* data, such as filling in a foreign key on a child item
+ dependent data, such as filling in a foreign key on a child item
from a new primary key, or deleting association rows before a
delete. This object acts as a proxy to a DependencyProcessor.
+
"""
-
def __init__(self, processor, targettask):
self.processor = processor
self.targettask = targettask
return elem.state
ret = False
- elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None and not elem.is_preprocessed(self)]
+ elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if not elem.is_preprocessed(self)]
if elements:
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
- elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None and not elem.is_preprocessed(self)]
+ elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if not elem.is_preprocessed(self)]
if elements:
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
"""process all objects contained within this ``UOWDependencyProcessor``s target task."""
if not delete:
- self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None], trans, delete=False)
+ self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements], trans, delete=False)
else:
- self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None], trans, delete=True)
+ self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements], trans, delete=True)
def get_object_dependencies(self, state, trans, passive):
return trans.get_attribute_history(state, self.processor.key, passive=passive)
        when topologically sorting on a per-instance basis.
"""
-
return self.processor.whose_dependent_on_who(state1, state2)
def branch(self, task):
is broken up into many individual ``UOWTask`` objects.
"""
-
return UOWDependencyProcessor(self.processor, task)
def execute_save_steps(self, trans, task):
self.save_objects(trans, task)
self.execute_cyclical_dependencies(trans, task, False)
- self.execute_per_element_childtasks(trans, task, False)
self.execute_dependencies(trans, task, False)
self.execute_dependencies(trans, task, True)
-
+
def execute_delete_steps(self, trans, task):
self.execute_cyclical_dependencies(trans, task, True)
- self.execute_per_element_childtasks(trans, task, True)
self.delete_objects(trans, task)
def execute_dependencies(self, trans, task, isdelete=None):
def execute_cyclical_dependencies(self, trans, task, isdelete):
for dep in task.polymorphic_cyclical_dependencies:
self.execute_dependency(trans, dep, isdelete)
-
- def execute_per_element_childtasks(self, trans, task, isdelete):
- for element in task.polymorphic_tosave_elements + task.polymorphic_todelete_elements:
- self.execute_element_childtasks(trans, element, isdelete)
-
- def execute_element_childtasks(self, trans, element, isdelete):
- for child in element.childtasks:
- self.execute(trans, [child], isdelete)
-
+ for t in task.dependent_tasks:
+ self.execute(trans, [t], isdelete)
"""Dumps out a string representation of a UOWTask structure"""
+from sqlalchemy import util
from sqlalchemy.orm import unitofwork
from sqlalchemy.orm import util as mapperutil
-from sqlalchemy import util
class UOWDumper(unitofwork.UOWExecutor):
- def __init__(self, tasks, buf, verbose=False):
- self.verbose = verbose
+ def __init__(self, tasks, buf):
self.indent = 0
self.tasks = tasks
self.buf = buf
- self.headers = {}
self.execute(None, tasks)
def execute(self, trans, tasks, isdelete=None):
for rec in l:
if rec.listonly:
continue
- self.header("Save elements"+ self._inheritance_tag(task))
self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n")
- self.closeheader()
def delete_objects(self, trans, task):
for rec in task.polymorphic_todelete_elements:
if rec.listonly:
continue
- self.header("Delete elements"+ self._inheritance_tag(task))
self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n")
- self.closeheader()
-
- def _inheritance_tag(self, task):
- if not self.verbose:
- return ""
- else:
- return (" (inheriting task %s)" % self._repr_task(task))
-
- def header(self, text):
- """Write a given header just once."""
-
- if not self.verbose:
- return
- try:
- self.headers[text]
- except KeyError:
- self.buf.write(self._indent() + "- " + text + "\n")
- self.headers[text] = True
-
- def closeheader(self):
- if not self.verbose:
- return
- self.buf.write(self._indent() + "- ------\n")
def execute_dependency(self, transaction, dep, isdelete):
self._dump_processor(dep, isdelete)
- def execute_save_steps(self, trans, task):
- super(UOWDumper, self).execute_save_steps(trans, task)
-
- def execute_delete_steps(self, trans, task):
- super(UOWDumper, self).execute_delete_steps(trans, task)
-
- def execute_dependencies(self, trans, task, isdelete=None):
- super(UOWDumper, self).execute_dependencies(trans, task, isdelete)
-
- def execute_cyclical_dependencies(self, trans, task, isdelete):
- self.header("Cyclical %s dependencies" % (isdelete and "delete" or "save"))
- super(UOWDumper, self).execute_cyclical_dependencies(trans, task, isdelete)
- self.closeheader()
-
- def execute_per_element_childtasks(self, trans, task, isdelete):
- super(UOWDumper, self).execute_per_element_childtasks(trans, task, isdelete)
-
- def execute_element_childtasks(self, trans, element, isdelete):
- self.header("%s subelements of UOWTaskElement(%s)" % ((isdelete and "Delete" or "Save"), hex(id(element))))
- super(UOWDumper, self).execute_element_childtasks(trans, element, isdelete)
- self.closeheader()
-
def _dump_processor(self, proc, deletes):
if deletes:
val = proc.targettask.polymorphic_todelete_elements
else:
val = proc.targettask.polymorphic_tosave_elements
- if self.verbose:
- self.buf.write(self._indent() + " +- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % (
- repr(proc.processor.key),
- ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")),
- hex(id(proc)),
- self._repr_task(proc.targettask))
- )
- elif False:
- self.buf.write(self._indent() + " +- %s attribute on %s\n" % (
- repr(proc.processor.key),
- ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")),
- )
- )
-
- if len(val) == 0:
- if self.verbose:
- self.buf.write(self._indent() + " +- " + "(no objects)\n")
for v in val:
self.buf.write(self._indent() + " +- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n")
objid = "%s.%s" % (mapperutil.state_str(te.state), attribute)
else:
objid = mapperutil.state_str(te.state)
- if self.verbose:
- return "%s (UOWTaskElement(%s, %s))" % (objid, hex(id(te)), (te.listonly and 'listonly' or (te.isdelete and 'delete' or 'save')))
- elif process:
+ if process:
return "Process %s" % (objid)
else:
return "%s %s" % ((te.isdelete and "Delete" or "Save"), objid)
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions
-from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql.util import row_adapter as create_row_adapter
-from sqlalchemy.sql import visitors
-from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator
+import new
-all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import sql, util
+from sqlalchemy.sql import expression, util as sql_util, operators
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty
+from sqlalchemy.orm import attributes
+
+all_cascades = util.FrozenSet(["delete", "delete-orphan", "all", "merge",
"expunge", "save-update", "refresh-expire", "none"])
+_INSTRUMENTOR = ('mapper', 'instrumentor')
+
class CascadeOptions(object):
"""Keeps track of the options sent to relation().cascade"""
self.refresh_expire = "refresh-expire" in values or "all" in values
for x in values:
if x not in all_cascades:
- raise exceptions.ArgumentError("Invalid cascade option '%s'" % x)
+ raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)
def __contains__(self, item):
return getattr(self, item.replace("-", "_"), False)
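+# e.g., a usage sketch (illustrative; assumes the constructor parses the
+# comma-separated string handed to relation(cascade="...")):
+#
+#   opts = CascadeOptions("save-update, merge")
+#   "save-update" in opts    # True, via __contains__ above
+#   opts.delete_orphan       # False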
result.append(sql.select([col(name, table) for name in colnames], from_obj=[table]))
return sql.union_all(*result).alias(aliasname)
+def identity_key(*args, **kwargs):
+ """Get an identity key.
-class ExtensionCarrier(object):
- """stores a collection of MapperExtension objects.
-
- allows an extension methods to be called on contained MapperExtensions
- in the order they were added to this object. Also includes a 'methods' dictionary
- accessor which allows for a quick check if a particular method
- is overridden on any contained MapperExtensions.
+ Valid call signatures:
+
+ * ``identity_key(class, ident, entity_name=None)``
+
+ class
+ mapped class (must be a positional argument)
+
+ ident
+ primary key, if the key is composite this is a tuple
+
+ entity_name
+ optional entity name
+
+ * ``identity_key(instance=instance)``
+
+ instance
+ object instance (must be given as a keyword arg)
+
+ * ``identity_key(class, row=row, entity_name=None)``
+
+ class
+ mapped class (must be a positional argument)
+
+ row
+ result proxy row (must be given as a keyword arg)
+
+ entity_name
+ optional entity name (must be given as a keyword arg)
"""
+ from sqlalchemy.orm import class_mapper, object_mapper
+ if args:
+ if len(args) == 1:
+ class_ = args[0]
+ try:
+ row = kwargs.pop("row")
+ except KeyError:
+ ident = kwargs.pop("ident")
+ entity_name = kwargs.pop("entity_name", None)
+ elif len(args) == 2:
+ class_, ident = args
+ entity_name = kwargs.pop("entity_name", None)
+ elif len(args) == 3:
+ class_, ident, entity_name = args
+ else:
+ raise sa_exc.ArgumentError("expected up to three "
+ "positional arguments, got %s" % len(args))
+ if kwargs:
+ raise sa_exc.ArgumentError("unknown keyword arguments: %s"
+ % ", ".join(kwargs.keys()))
+ mapper = class_mapper(class_, entity_name=entity_name)
+ if "ident" in locals():
+ return mapper.identity_key_from_primary_key(ident)
+ return mapper.identity_key_from_row(row)
+ instance = kwargs.pop("instance")
+ if kwargs:
+ raise sa_exc.ArgumentError("unknown keyword arguments: %s"
+ % ", ".join(kwargs.keys()))
+ mapper = object_mapper(instance)
+ return mapper.identity_key_from_instance(instance)
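+# Usage sketch (illustrative; ``User`` is a hypothetical mapped class):
+#
+#   identity_key(User, 1)              # from class + primary key
+#   identity_key(instance=some_user)   # from an existing instance
+#   identity_key(User, row=row)        # from a result proxy row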
- def __init__(self, _elements=None):
+class ExtensionCarrier(object):
+ """Fronts an ordered collection of MapperExtension objects.
+
+ Bundles multiple MapperExtensions into a unified callable unit,
+ encapsulating ordering, looping and EXT_CONTINUE logic. The
+ ExtensionCarrier implements the MapperExtension interface, e.g.::
+
+ carrier.after_insert(...args...)
+
+    Also includes a 'methods' dictionary accessor which allows a quick
+    check of whether a particular method is overridden on any contained
+    MapperExtensions.
+
+ """
+
+ interface = util.Set([method for method in dir(MapperExtension)
+ if not method.startswith('_')])
+
+ def __init__(self, extensions=None):
self.methods = {}
- if _elements is not None:
- self.__elements = [self.__inspect(e) for e in _elements]
- else:
- self.__elements = []
-
- def copy(self):
- return ExtensionCarrier(list(self.__elements))
-
- def __iter__(self):
- return iter(self.__elements)
+ self._extensions = []
+ for ext in extensions or ():
+ self.append(ext)
- def insert(self, extension):
- """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
+ def copy(self):
+ return ExtensionCarrier(self._extensions)
- self.__elements.insert(0, self.__inspect(extension))
+ def push(self, extension):
+ """Insert a MapperExtension at the beginning of the collection."""
+ self._register(extension)
+ self._extensions.insert(0, extension)
def append(self, extension):
- """Append a MapperExtension at the end of this ExtensionCarrier's list."""
+ """Append a MapperExtension at the end of the collection."""
+ self._register(extension)
+ self._extensions.append(extension)
- self.__elements.append(self.__inspect(extension))
+ def __iter__(self):
+ """Iterate over MapperExtensions in the collection."""
+ return iter(self._extensions)
+
+ def _register(self, extension):
+ """Register callable fronts for overridden interface methods."""
+ for method in self.interface:
+ if method in self.methods:
+ continue
+ impl = getattr(extension, method, None)
+ if impl and impl is not getattr(MapperExtension, method):
+ self.methods[method] = self._create_do(method)
+
+ def _create_do(self, method):
+ """Return a closure that loops over impls of the named method."""
- def __inspect(self, extension):
- for meth in MapperExtension.__dict__.keys():
- if meth not in self.methods and hasattr(extension, meth) and getattr(extension, meth) is not getattr(MapperExtension, meth):
- self.methods[meth] = self.__create_do(meth)
- return extension
-
- def __create_do(self, funcname):
def _do(*args, **kwargs):
- for elem in self.__elements:
- ret = getattr(elem, funcname)(*args, **kwargs)
+ for ext in self._extensions:
+ ret = getattr(ext, method)(*args, **kwargs)
if ret is not EXT_CONTINUE:
return ret
else:
return EXT_CONTINUE
-
try:
- _do.__name__ = funcname
+            # __name__ is not assignable on Python 2.3; skip silently there
+            _do.__name__ = method
except:
- # cant set __name__ in py 2.3
pass
return _do
-
- def _pass(self, *args, **kwargs):
+
+ def _pass(*args, **kwargs):
return EXT_CONTINUE
-
+ _pass = staticmethod(_pass)
+
def __getattr__(self, key):
+ """Delegate MapperExtension methods to bundled fronts."""
+ if key not in self.interface:
+ raise AttributeError(key)
return self.methods.get(key, self._pass)
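+# Dispatch sketch (illustrative; ``AuditExtension`` is a hypothetical
+# MapperExtension subclass):
+#
+#   carrier = ExtensionCarrier([AuditExtension()])
+#   'before_insert' in carrier.methods   # True only if some extension overrides it
+#   carrier.after_delete(mapper, connection, instance)
+#   # loops over the extensions until one returns something other than
+#   # EXT_CONTINUE; unoverridden methods fall through to the _pass stub.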
-class AliasedClauses(object):
- """Creates aliases of a mapped tables for usage in ORM queries, and provides expression adaptation."""
-
- def __init__(self, alias, equivalents=None, chain_to=None, should_adapt=True):
- self.alias = alias
- self.equivalents = equivalents
- self.row_decorator = self._create_row_adapter()
- self.should_adapt = should_adapt
- if should_adapt:
- self.adapter = sql_util.ClauseAdapter(self.alias, equivalents=equivalents)
+class ORMAdapter(sql_util.ColumnAdapter):
+ def __init__(self, entity, equivalents=None, chain_to=None):
+ mapper, selectable, is_aliased_class = _entity_info(entity)
+ if is_aliased_class:
+ self.aliased_class = entity
else:
- self.adapter = visitors.NullVisitor()
-
- if chain_to:
- self.adapter.chain(chain_to.adapter)
-
- def aliased_column(self, column):
- if not self.should_adapt:
- return column
-
- conv = self.alias.corresponding_column(column)
- if conv:
- return conv
-
- # process column-level subqueries
- aliased_column = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).traverse(column, clone=True)
-
- # anonymize labels which might have specific names
- if isinstance(aliased_column, expression._Label):
- aliased_column = aliased_column.label(None)
-
- # add to row decorator explicitly
- self.row_decorator({}).map[column] = aliased_column
- return aliased_column
-
- def adapt_clause(self, clause):
- return self.adapter.traverse(clause, clone=True)
-
- def adapt_list(self, clauses):
- return self.adapter.copy_and_process(clauses)
-
- def _create_row_adapter(self):
- return create_row_adapter(self.alias, equivalent_columns=self.equivalents)
+ self.aliased_class = None
+ sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to)
+class AliasedClass(object):
+ def __init__(self, cls, alias=None, name=None):
+ self.__mapper = _class_to_mapper(cls)
+ self.__target = self.__mapper.class_
+ alias = alias or self.__mapper._with_polymorphic_selectable.alias()
+ self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
+ self.__alias = alias
+ self._sa_label_name = name
+ self.__name__ = 'AliasedClass_' + str(self.__target)
+
+ def __adapt_prop(self, prop):
+ existing = getattr(self.__target, prop.key)
+ comparator = AliasedComparator(self, self.__adapter, existing.comparator)
+ queryattr = attributes.QueryableAttribute(
+ existing.impl, parententity=self, comparator=comparator)
+ setattr(self, prop.key, queryattr)
+ return queryattr
-class PropertyAliasedClauses(AliasedClauses):
- """extends AliasedClauses to add support for primary/secondary joins on a relation()."""
-
- def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None, alias=None, should_adapt=True):
- self.prop = prop
- self.mapper = self.prop.mapper
- self.table = self.prop.table
- self.parentclauses = parentclauses
-
- if not alias:
- from_obj = self.mapper._with_polymorphic_selectable()
- alias = from_obj.alias()
-
- super(PropertyAliasedClauses, self).__init__(alias, equivalents=self.mapper._equivalent_columns, chain_to=parentclauses, should_adapt=should_adapt)
-
- if prop.secondary:
- self.secondary = prop.secondary.alias()
- primary_aliasizer = sql_util.ClauseAdapter(self.secondary)
- secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary))
-
- if parentclauses is not None:
- primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, equivalents=parentclauses.equivalents))
-
- self.secondaryjoin = secondary_aliasizer.traverse(secondaryjoin, clone=True)
- self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
- else:
- primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents)
- if parentclauses is not None:
- primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side, equivalents=parentclauses.equivalents))
-
- self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
- self.secondary = None
- self.secondaryjoin = None
-
- if prop.order_by:
- if prop.secondary:
- # usually this is not used but occasionally someone has a sort key in their secondary
- # table, even tho SA does not support writing this column directly
- self.order_by = secondary_aliasizer.copy_and_process(util.to_list(prop.order_by))
+ def __getattr__(self, key):
+ prop = self.__mapper._get_property(key, raiseerr=False)
+ if prop:
+ return self.__adapt_prop(prop)
+
+ for base in self.__target.__mro__:
+ try:
+ attr = object.__getattribute__(base, key)
+ except AttributeError:
+ continue
else:
- self.order_by = primary_aliasizer.copy_and_process(util.to_list(prop.order_by))
-
+ break
else:
- self.order_by = None
+ raise AttributeError(key)
-class AliasedClass(object):
- def __new__(cls, target):
- from sqlalchemy.orm import attributes
- mapper = _class_to_mapper(target)
- alias = mapper.mapped_table.alias()
- retcls = type(target.__name__ + "Alias", (cls,), {'alias':alias})
- retcls._class_state = mapper._class_state
- for prop in mapper.iterate_properties:
- existing = mapper._class_state.attrs[prop.key]
- setattr(retcls, prop.key, attributes.InstrumentedAttribute(existing.impl, comparator=AliasedComparator(alias, existing.comparator)))
-
- return retcls
+ if hasattr(attr, 'func_code'):
+ is_method = getattr(self.__target, key, None)
+ if is_method and is_method.im_self is not None:
+ return new.instancemethod(attr.im_func, self, self)
+ else:
+ return None
+ elif hasattr(attr, '__get__'):
+ return attr.__get__(None, self)
+ else:
+ return attr
- def __init__(self, alias):
- self.alias = alias
+ def __repr__(self):
+ return '<AliasedClass at 0x%x; %s>' % (
+ id(self), self.__target.__name__)
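+# Usage sketch (illustrative; ``User`` is a hypothetical mapped class):
+#
+#   ualias = AliasedClass(User)
+#   session.query(User.name, ualias.name).filter(User.id > ualias.id)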
class AliasedComparator(PropComparator):
- def __init__(self, alias, comparator):
- self.alias = alias
+ def __init__(self, aliasedclass, adapter, comparator):
+ self.aliasedclass = aliasedclass
self.comparator = comparator
- self.adapter = sql_util.ClauseAdapter(alias)
+ self.adapter = adapter
+ self.__clause_element = self.adapter.traverse(self.comparator.__clause_element__())._annotate({'parententity': aliasedclass})
- def clause_element(self):
- return self.adapter.traverse(self.comparator.clause_element(), clone=True)
+ def __clause_element__(self):
+ return self.__clause_element
def operate(self, op, *other, **kwargs):
- return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs), clone=True)
+ return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs))
def reverse_operate(self, op, other, **kwargs):
- return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs), clone=True)
-
-from sqlalchemy.sql import expression
-_selectable = expression._selectable
-def _orm_selectable(selectable):
- if _is_mapped_class(selectable):
- if _is_aliased_class(selectable):
- return selectable.alias
- else:
- return _class_to_mapper(selectable)._with_polymorphic_selectable()
- else:
- return _selectable(selectable)
-expression._selectable = _orm_selectable
+        return self.adapter.traverse(self.comparator.reverse_operate(op, other, **kwargs))
+
+def _orm_annotate(element, exclude=None):
+ def clone(elem):
+ if exclude and elem in exclude:
+ elem = elem._clone()
+ elif '_orm_adapt' not in elem._annotations:
+ elem = elem._annotate({'_orm_adapt':True})
+ elem._copy_internals(clone=clone)
+ return elem
+
+ if element is not None:
+ element = clone(element)
+ return element
+
class _ORMJoin(expression.Join):
- """future functionality."""
__visit_name__ = expression.Join.__visit_name__
-
+
def __init__(self, left, right, onclause=None, isouter=False):
- if _is_mapped_class(left) or _is_mapped_class(right):
- if hasattr(left, '_orm_mappers'):
- left_mapper = left._orm_mappers[1]
- adapt_from = left.right
+ if hasattr(left, '_orm_mappers'):
+ left_mapper = left._orm_mappers[1]
+ adapt_from = left.right
+
+ else:
+ left_mapper, left, left_is_aliased = _entity_info(left)
+ if left_is_aliased or not left_mapper:
+ adapt_from = left
else:
- left_mapper = _class_to_mapper(left)
- if _is_aliased_class(left):
- adapt_from = left.alias
- else:
- adapt_from = None
+ adapt_from = None
- right_mapper = _class_to_mapper(right)
+ right_mapper, right, right_is_aliased = _entity_info(right)
+ if right_is_aliased:
+ adapt_to = right
+ else:
+ adapt_to = None
+
+ if left_mapper or right_mapper:
self._orm_mappers = (left_mapper, right_mapper)
-
+
if isinstance(onclause, basestring):
prop = left_mapper.get_property(onclause)
+ elif isinstance(onclause, attributes.QueryableAttribute):
+ adapt_from = onclause.__clause_element__()
+ prop = onclause.property
+ elif isinstance(onclause, MapperProperty):
+ prop = onclause
+ else:
+ prop = None
- if _is_aliased_class(right):
- adapt_to = right.alias
- else:
- adapt_to = None
-
- pj, sj, source, dest, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True)
+ if prop:
+ pj, sj, source, dest, secondary, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True)
if sj:
- left = sql.join(left, prop.secondary, onclause=pj)
+ left = sql.join(left, secondary, pj, isouter)
onclause = sj
else:
onclause = pj
+ self._target_adapter = target_adapter
+
expression.Join.__init__(self, left, right, onclause, isouter)
def join(self, right, onclause=None, isouter=False):
def outerjoin(self, right, onclause=None):
return _ORMJoin(self, right, onclause, True)
-def _join(left, right, onclause=None):
- """future functionality."""
-
- return _ORMJoin(left, right, onclause, False)
-
-def _outerjoin(left, right, onclause=None):
- """future functionality."""
+def join(left, right, onclause=None, isouter=False):
+ return _ORMJoin(left, right, onclause, isouter)
+def outerjoin(left, right, onclause=None):
return _ORMJoin(left, right, onclause, True)
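+# Usage sketch (illustrative; ``User`` and ``Address`` are hypothetical
+# mapped classes related via a ``User.addresses`` relation()):
+#
+#   j = join(User, Address, User.addresses)  # ON clause derived from the relation
+#   session.query(User).select_from(j).filter(Address.email == 'ed@foo.com')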
-
-def has_identity(object):
- return hasattr(object, '_instance_key')
-def _state_has_identity(state):
- return '_instance_key' in state.dict
+def with_parent(instance, prop):
+ """Return criterion which selects instances with a given parent.
-def _is_mapped_class(cls):
- return hasattr(cls, '_class_state')
+ instance
+ a parent instance, which should be persistent or detached.
+
+    prop
+      a class-attached descriptor, MapperProperty or string property
+      name attached to the parent instance.
-def _is_aliased_class(obj):
- return isinstance(obj, type) and issubclass(obj, AliasedClass)
-
-def has_mapper(object):
- """Return True if the given object has had a mapper association
- set up, either through loading, or via insertion in a session.
"""
+ if isinstance(prop, basestring):
+ mapper = object_mapper(instance)
+ prop = mapper.get_property(prop, resolve_synonyms=True)
+ elif isinstance(prop, attributes.QueryableAttribute):
+ prop = prop.property
+
+ return prop.compare(operators.eq, instance, value_is_parent=True)
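+# Usage sketch (illustrative; ``someuser``, ``Address`` and
+# ``User.addresses`` are hypothetical):
+#
+#   session.query(Address).filter(with_parent(someuser, User.addresses))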
+
+
+def _entity_info(entity, entity_name=None, compile=True):
+ if isinstance(entity, AliasedClass):
+ return entity._AliasedClass__mapper, entity._AliasedClass__alias, True
+ elif _is_mapped_class(entity):
+ if isinstance(entity, type):
+ mapper = class_mapper(entity, entity_name, compile)
+ else:
+ if compile:
+ mapper = entity.compile()
+ else:
+ mapper = entity
+ return mapper, mapper._with_polymorphic_selectable, False
+ else:
+ return None, entity, False
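+# Return value sketch (illustrative; the names are hypothetical):
+#
+#   _entity_info(User)               # (user_mapper, mapped selectable, False)
+#   _entity_info(AliasedClass(User)) # (user_mapper, the alias, True)
+#   _entity_info(users_table)       # (None, users_table, False)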
+
+def _entity_descriptor(entity, key):
+ if isinstance(entity, AliasedClass):
+ desc = getattr(entity, key)
+ return desc, desc.property
+ elif isinstance(entity, type):
+ desc = attributes.manager_of_class(entity)[key]
+ return desc, desc.property
+ else:
+ desc = entity.class_manager[key]
+ return desc, desc.property
+
+def _orm_columns(entity):
+ mapper, selectable, is_aliased_class = _entity_info(entity)
+ if isinstance(selectable, expression.Selectable):
+ return [c for c in selectable.c]
+ else:
+ return [selectable]
+
+def _orm_selectable(entity):
+ mapper, selectable, is_aliased_class = _entity_info(entity)
+ return selectable
- return hasattr(object, '_entity_name')
+def _is_aliased_class(entity):
+ return isinstance(entity, AliasedClass)
def _state_mapper(state, entity_name=None):
- return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)]
+ if state.entity_name is not attributes.NO_ENTITY_NAME:
+ # Override the given entity name if the object is not transient.
+ entity_name = state.entity_name
+ return state.manager.mappers[entity_name]
def object_mapper(object, entity_name=None, raiseerror=True):
"""Given an object, return the primary Mapper associated with the object instance.
be located. If False, return None.
"""
-
- try:
- mapper = object.__class__._class_state.mappers[getattr(object, '_entity_name', entity_name)]
- except (KeyError, AttributeError):
- if raiseerror:
- raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', entity_name)))
- else:
- return None
- return mapper
+ state = attributes.instance_state(object)
+ if state.entity_name is not attributes.NO_ENTITY_NAME:
+ # Override the given entity name if the object is not transient.
+ entity_name = state.entity_name
+ return class_mapper(
+ type(object), entity_name=entity_name,
+ compile=False, raiseerror=raiseerror)
def class_mapper(class_, entity_name=None, compile=True, raiseerror=True):
- """Given a class and optional entity_name, return the primary Mapper associated with the key.
+ """Given a class (or an object) and optional entity_name, return the primary Mapper associated with the key.
If no mapper can be located, raises ``InvalidRequestError``.
- """
+ """
+
+ if not isinstance(class_, type):
+ class_ = type(class_)
try:
- mapper = class_._class_state.mappers[entity_name]
+ class_manager = attributes.manager_of_class(class_)
+ mapper = class_manager.mappers[entity_name]
except (KeyError, AttributeError):
- if raiseerror:
- raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name))
- else:
- return None
+ if not raiseerror:
+ return
+ raise sa_exc.InvalidRequestError(
+ "Class '%s' entity name '%s' has no mapper associated with it" %
+ (class_.__name__, entity_name))
if compile:
- return mapper.compile()
- else:
- return mapper
+ mapper = mapper.compile()
+ return mapper
def _class_to_mapper(class_or_mapper, entity_name=None, compile=True):
- if isinstance(class_or_mapper, type):
+ if _is_aliased_class(class_or_mapper):
+ return class_or_mapper._AliasedClass__mapper
+ elif isinstance(class_or_mapper, type):
return class_mapper(class_or_mapper, entity_name=entity_name, compile=compile)
else:
if compile:
else:
return class_or_mapper
+def has_identity(object):
+ state = attributes.instance_state(object)
+ return _state_has_identity(state)
+
+def _state_has_identity(state):
+ return bool(state.key)
+
+def has_mapper(object):
+ state = attributes.instance_state(object)
+ return _state_has_mapper(state)
+
+def _state_has_mapper(state):
+ return state.entity_name is not attributes.NO_ENTITY_NAME
+
+def _is_mapped_class(cls):
+ from sqlalchemy.orm import mapperlib as mapper
+ if isinstance(cls, (AliasedClass, mapper.Mapper)):
+ return True
+
+ manager = attributes.manager_of_class(cls)
+ return manager and _INSTRUMENTOR in manager.info
+
def instance_str(instance):
"""Return a string describing an instance."""
- return instance.__class__.__name__ + "@" + hex(id(instance))
+ return state_str(attributes.instance_state(instance))
def state_str(state):
"""Return a string describing an instance."""
def attribute_str(instance, attribute):
return instance_str(instance) + "." + attribute
+def state_attribute_str(state, attribute):
+ return state_str(state) + "." + attribute
+
def identity_equal(a, b):
if a is b:
return True
- id_a = getattr(a, '_instance_key', None)
- id_b = getattr(b, '_instance_key', None)
- if id_a is None or id_b is None:
+ if a is None or b is None:
+ return False
+ try:
+ state_a = attributes.instance_state(a)
+ state_b = attributes.instance_state(b)
+ except (KeyError, AttributeError):
+ return False
+ if state_a.key is None or state_b.key is None:
return False
- return id_a == id_b
+ return state_a.key == state_b.key
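+# e.g. (illustrative; ``User`` is a hypothetical mapped class):
+#
+#   u1 = sess_a.query(User).get(5)
+#   u2 = sess_b.query(User).get(5)
+#   identity_equal(u1, u2)      # True: both states carry the same identity key
+#   identity_equal(u1, User())  # False: a pending instance has no key yet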
+# TODO: Avoid circular import.
+attributes.identity_equal = identity_equal
+attributes._is_aliased_class = _is_aliased_class
+attributes._entity_info = _entity_info
import weakref, time
-from sqlalchemy import exceptions, logging
+from sqlalchemy import exc, log
from sqlalchemy import queue as Queue
from sqlalchemy.util import thread, threading, pickle, as_interface
"""
def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=True,
reset_on_return=True, listeners=None):
- self.logger = logging.instance_logger(self, echoflag=echo)
+ self.logger = log.instance_logger(self, echoflag=echo)
# the WeakValueDictionary works more nicely than a regular dict of
# weakrefs. the latter can pile up dead reference objects which don't
# get cleaned out. WVD adds from 1-6 method calls to a checkout
return self._connection_record.info
except AttributeError:
if self.connection is None:
- raise exceptions.InvalidRequestError("This connection is closed")
+ raise exc.InvalidRequestError("This connection is closed")
try:
return self._detached_info
except AttributeError:
"""
if self.connection is None:
- raise exceptions.InvalidRequestError("This connection is closed")
+ raise exc.InvalidRequestError("This connection is closed")
if self._connection_record is not None:
self._connection_record.invalidate(e=e)
self.connection = None
def checkout(self):
if self.connection is None:
- raise exceptions.InvalidRequestError("This connection is closed")
- self.__counter +=1
+ raise exc.InvalidRequestError("This connection is closed")
+ self.__counter += 1
if not self._pool._on_checkout or self.__counter != 1:
return self
for l in self._pool._on_checkout:
l.checkout(self.connection, self._connection_record, self)
return self
- except exceptions.DisconnectionError, e:
+ except exc.DisconnectionError, e:
if self._pool._should_log_info:
self._pool.log(
"Disconnection detected on checkout: %s" % e)
if self._pool._should_log_info:
self._pool.log("Reconnection attempts exhausted on checkout")
self.invalidate()
- raise exceptions.InvalidRequestError("This connection is closed")
+ raise exc.InvalidRequestError("This connection is closed")
def detach(self):
"""Separate this connection from its Pool.
self._connection_record = None
def close(self):
- self.__counter -=1
+ self.__counter -= 1
if self.__counter == 0:
self._close()
if not wait:
return self.do_get()
else:
- raise exceptions.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out, timeout %d" % (self.size(), self.overflow(), self._timeout))
+ raise exc.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out, timeout %d" % (self.size(), self.overflow(), self._timeout))
if self._overflow_lock is not None:
self._overflow_lock.acquire()
return "NullPool"
def do_return_conn(self, conn):
- conn.close()
+ conn.close()
def do_return_invalid(self, conn):
- pass
+ pass
def do_get(self):
return self.create_connection()
expressions. """
import re, inspect
-from sqlalchemy import types, exceptions, util, databases
+from sqlalchemy import types, exc, util, databases
from sqlalchemy.sql import expression, visitors
URL = None
"""Base class for items that define a database schema."""
__metaclass__ = expression._FigureVisitName
-
+ quote = None
+
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
-
+
for item in args:
if item is not None:
item._set_parent(self)
try:
table = metadata.tables[key]
if not useexisting and table._cant_override(*args, **kwargs):
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"Table '%s' is already defined for this MetaData instance. "
"Specify 'useexisting=True' to redefine options and "
"columns on an existing Table object." % key)
return table
except KeyError:
if mustexist:
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
"Table '%s' not defined" % (key))
try:
return type.__call__(self, name, metadata, *args, **kwargs)
Deprecated; this is an oracle-only argument - "schema" should
be used in its place.
- quote
- When True, indicates that the Table identifier must be quoted.
- This flag does *not* disable quoting; for case-insensitive names,
- use an all lower case identifier.
+      quote
+        Force quoting of the identifier on or off, corresponding to
+        `True` or `False`; defaults to `None`. This flag is rarely
+        needed, as quoting is normally applied automatically for known
+        reserved words as well as for "case sensitive" identifiers. An
+        identifier is "case sensitive" if it contains non-lowercase
+        letters, and is otherwise considered "case insensitive". See
+        the example below.
quote_schema
- When True, indicates that the schema identifier must be quoted.
- This flag does *not* disable quoting; for case-insensitive names,
- use an all lower case identifier.
+ same as 'quote' but applies to the schema identifier.
+
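+      For example (an illustrative sketch; ``metadata`` is assumed to
+      be an existing MetaData)::
+
+        Table('user', metadata,
+              Column('id', Integer, primary_key=True),
+              quote=True)   # force quoting of the lower-case name 'user'
+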
"""
-
super(Table, self).__init__(name)
self.metadata = metadata
self.schema = kwargs.pop('schema', kwargs.pop('owner', None))
self._set_parent(metadata)
- self.__extra_kwargs(**kwargs)
+ self.__extra_kwargs(**kwargs)
# load column definitions from the database if 'autoload' is defined
# we do it after the table is in the singleton dictionary to support
autoload_with = kwargs.pop('autoload_with', None)
schema = kwargs.pop('schema', None)
if schema and schema != self.schema:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Can't change schema of existing table from '%s' to '%s'",
(self.schema, schema))
['autoload', 'autoload_with', 'schema', 'owner']))
def __extra_kwargs(self, **kwargs):
- self.quote = kwargs.pop('quote', False)
- self.quote_schema = kwargs.pop('quote_schema', False)
+ self.quote = kwargs.pop('quote', None)
+ self.quote_schema = kwargs.pop('quote_schema', None)
if kwargs.get('info'):
self._info = kwargs.pop('info')
or subtype of Integer.
quote
- When True, indicates that the Column identifier must be quoted.
- This flag does *not* disable quoting; for case-insensitive names,
- use an all lower case identifier.
+        Force quoting of the identifier on or off, corresponding to
+        `True` or `False`; defaults to `None`. This flag is rarely
+        needed, as quoting is normally applied automatically for known
+        reserved words as well as for "case sensitive" identifiers. An
+        identifier is "case sensitive" if it contains non-lowercase
+        letters, and is otherwise considered "case insensitive".
"""
name = kwargs.pop('name', None)
args = list(args)
if isinstance(args[0], basestring):
if name is not None:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"May not pass name positionally and as a keyword.")
name = args.pop(0)
if args:
(isinstance(args[0], type) and
issubclass(args[0], types.AbstractType))):
if type_ is not None:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"May not pass type_ positionally and as a keyword.")
type_ = args.pop(0)
self.default = kwargs.pop('default', None)
self.index = kwargs.pop('index', None)
self.unique = kwargs.pop('unique', None)
- self.quote = kwargs.pop('quote', False)
+ self.quote = kwargs.pop('quote', None)
self.onupdate = kwargs.pop('onupdate', None)
self.autoincrement = kwargs.pop('autoincrement', True)
self.constraints = util.Set()
self.foreign_keys = util.OrderedSet()
+ util.set_creation_order(self)
+
if kwargs.get('info'):
self._info = kwargs.pop('info')
if kwargs:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Unknown arguments passed to Column: " + repr(kwargs.keys()))
def __str__(self):
bind = property(bind)
def references(self, column):
- """Return True if this references the given column via a foreign key."""
+ """Return True if this Column references the given column via foreign key."""
for fk in self.foreign_keys:
if fk.references(column.table):
return True
def _set_parent(self, table):
if self.name is None:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Column must be constructed with a name or assign .name "
"before adding to a Table.")
if self.key is None:
self.key = self.name
self.metadata = table.metadata
if getattr(self, 'table', None) is not None:
- raise exceptions.ArgumentError("this Column already has a table!")
+ raise exc.ArgumentError("this Column already has a table!")
if not self._is_oid:
self._pre_existing_column = table._columns.get(self.key)
if self.primary_key:
table.primary_key.replace(self)
elif self.key in table.primary_key:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Trying to redefine primary-key column '%s' as a "
"non-primary-key column on table '%s'" % (
self.key, table.fullname))
if self.index:
if isinstance(self.index, basestring):
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"The 'index' keyword argument on Column is boolean only. "
"To create indexes with a specific name, create an "
"explicit Index object external to the Table.")
Index('ix_%s' % self._label, self, unique=self.unique)
elif self.unique:
if isinstance(self.unique, basestring):
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"The 'unique' keyword argument on Column is boolean only. "
"To create unique constraints or indexes with a specific "
"name, append an explicit UniqueConstraint to the Table's "
"""Create a copy of this ``Column``, unitialized.
This is used in ``Table.tometadata``.
- """
+ """
return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, index=self.index, autoincrement=self.autoincrement, *[c.copy() for c in self.constraints])
-
- def _make_proxy(self, selectable, name = None):
+
+ def _make_proxy(self, selectable, name=None):
"""Create a *proxy* for this column.
This is a copy of this ``Column`` referenced by a different parent
(such as an alias or select statement).
- """
+ """
fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk)
c.table = selectable
[c._init_items(f) for f in fk]
return c
-
def get_children(self, schema_visitor=False, **kwargs):
if schema_visitor:
return [x for x in (self.default, self.onupdate) if x is not None] + \
For a composite (multiple column) FOREIGN KEY, use a ForeignKeyConstraint
within the Table definition.
- """
+ """
def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None):
"""Construct a column-level FOREIGN KEY.
def references(self, table):
"""Return True if the given table is referenced by this ForeignKey."""
-
return table.corresponding_column(self.column) is not None
def get_referent(self, table):
"""Return the column in the given table referenced by this ForeignKey.
Returns None if this ``ForeignKey`` does not reference the given table.
+
"""
+
return table.corresponding_column(self.column)
def column(self):
parenttable = c.table
break
else:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Parent column '%s' does not descend from a "
"table-attached Column" % str(self.parent))
m = re.match(r"^(.+?)(?:\.(.+?))?(?:\.(.+?))?$", self._colspec,
re.UNICODE)
if m is None:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Invalid foreign key column specification: %s" %
self._colspec)
if m.group(3) is None:
(tname, colname) = m.group(1, 2)
schema = None
else:
- (schema,tname,colname) = m.group(1,2,3)
+ (schema, tname, colname) = m.group(1, 2, 3)
if _get_table_key(tname, schema) not in parenttable.metadata:
- raise exceptions.NoReferencedTableError(
+ raise exc.NoReferencedTableError(
"Could not find table '%s' with which to generate a "
"foreign key" % tname)
table = Table(tname, parenttable.metadata,
else:
self._column = table.c[colname]
except KeyError, e:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Could not create ForeignKey '%s' on table '%s': "
"table '%s' has no column named '%s'" % (
self._colspec, parenttable.name, table.name, str(e)))
-
- elif isinstance(self._colspec, expression.Operators):
- self._column = self._colspec.clause_element()
+
+ elif hasattr(self._colspec, '__clause_element__'):
+ self._column = self._colspec.__clause_element__()
else:
self._column = self._colspec
defaulted = argspec[3] is not None and len(argspec[3]) or 0
if positionals - defaulted > 1:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"ColumnDefault Python function takes zero or one "
"positional arguments")
return fn
-
def _visit_name(self):
if self.for_update:
return "column_onupdate"
"""Represents a named database sequence."""
def __init__(self, name, start=None, increment=None, schema=None,
- optional=False, quote=False, **kwargs):
+ optional=False, quote=None, **kwargs):
super(Sequence, self).__init__(**kwargs)
self.name = name
self.start = start
self.increment = increment
- self.optional=optional
+ self.optional = optional
self.quote = quote
self.schema = schema
self.kwargs = kwargs
bind = _bind_or_error(self)
bind.drop(self, checkfirst=checkfirst)
-
class Constraint(SchemaItem):
"""A table-level SQL constraint, such as a KEY.
self.initially = initially
def __contains__(self, x):
- return self.columns.contains_column(x)
-
+ return x in self.columns
+
+ def contains_column(self, col):
+ return self.columns.contains_column(col)
+
def keys(self):
return self.columns.keys()
self.onupdate = onupdate
self.ondelete = ondelete
if self.name is None and use_alter:
- raise exceptions.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
+ raise exc.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
self.use_alter = use_alter
def _set_parent(self, table):
if self not in table.constraints:
table.constraints.add(self)
for (c, r) in zip(self.__colnames, self.__refcolnames):
- self.append_element(c,r)
+ self.append_element(c, r)
def append_element(self, col, refcol):
fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
deferrable=kwargs.pop('deferrable', None),
initially=kwargs.pop('initially', None))
if kwargs:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
'Unknown PrimaryKeyConstraint argument(s): %s' %
', '.join([repr(x) for x in kwargs.keys()]))
def add(self, col):
self.columns.add(col)
- col.primary_key=True
+ col.primary_key = True
append_column = add
def replace(self, col):
self.columns.replace(col)
def remove(self, col):
- col.primary_key=False
+ col.primary_key = False
del self.columns[col.key]
def copy(self):
deferrable=kwargs.pop('deferrable', None),
initially=kwargs.pop('initially', None))
if kwargs:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
'Unknown UniqueConstraint argument(s): %s' %
', '.join([repr(x) for x in kwargs.keys()]))
self._set_parent(column.table)
elif column.table != self.table:
# all columns must be from the same table
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"All index columns must be from same table. "
"%s is from %s not %s" % (column, column.table, self.table))
elif column.name in [ c.name for c in self.columns ]:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"A column may not appear twice in the "
"same index (%s already has column %s)" % (self.name, column))
self.columns.append(column)
self.ddl_listeners = util.defaultdict(list)
if reflect:
if not bind:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"A bind must be supplied in conjunction with reflect=True")
self.reflect()
missing = [name for name in only if name not in available]
if missing:
s = schema and (" schema '%s'" % schema) or ''
- raise exceptions.InvalidRequestError(
+ raise exc.InvalidRequestError(
'Could not reflect: requested table(s) not available '
'in %s%s: (%s)' % (bind.engine.url, s, ', '.join(missing)))
load = [name for name in only if name not in current]
"""
if not isinstance(statement, basestring):
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Expected a string or unicode SQL statement, got '%r'" %
statement)
if (on is not None and
(not isinstance(on, basestring) and not callable(on))):
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Expected the name of a database dialect or a callable for "
"'on' criteria, got type '%s'." % type(on).__name__)
"""
if not hasattr(schema_item, 'ddl_listeners'):
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"%s does not support DDL events" % type(schema_item).__name__)
if event not in schema_item.ddl_events:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Unknown event, expected one of (%s), got '%r'" %
(', '.join(schema_item.ddl_events), event))
schema_item.ddl_listeners[event].append(self)
'Execution can not proceed without a database to execute '
'against. Either execute with an explicit connection or '
'assign %s to enable implicit execution.') % (item, bindable)
- raise exceptions.UnboundExecutionError(msg)
+ raise exc.UnboundExecutionError(msg)
return bind
from sqlalchemy.sql.expression import *
-from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor
+from sqlalchemy.sql.visitors import ClauseVisitor
"""
import string, re, itertools
-from sqlalchemy import schema, engine, util, exceptions
+from sqlalchemy import schema, engine, util, exc
from sqlalchemy.sql import operators, functions
from sqlalchemy.sql import expression as sql
paradigm as visitors.ClauseVisitor but implements its own traversal.
"""
- __traverse_options__ = {'column_collections':False, 'entry':True}
-
operators = OPERATORS
functions = FUNCTIONS
# for aliases
self.generated_ids = {}
- # paramstyle from the dialect (comes from DB-API)
- self.paramstyle = self.dialect.paramstyle
-
# true if the paramstyle is positional
self.positional = self.dialect.positional
+ if self.positional:
+ self.positiontup = []
- self.bindtemplate = BIND_TEMPLATES[self.paramstyle]
-
- # a list of the compiled's bind parameter names, used to help
- # formulate a positional argument list
- self.positiontup = []
+ self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle]
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = self.dialect.identifier_preparer
return ""
def visit_grouping(self, grouping, **kwargs):
- return "(" + self.process(grouping.elem) + ")"
+ return "(" + self.process(grouping.element) + ")"
- def visit_label(self, label, result_map=None):
+ def visit_label(self, label, result_map=None, render_labels=False):
+ if not render_labels:
+ return self.process(label.element)
+
labelname = self._truncated_identifier("colident", label.name)
if result_map is not None:
- result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type)
+ result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
- return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
+ return " ".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
def visit_column(self, column, result_map=None, **kwargs):
if getattr(column, "is_literal", False):
name = self.escape_literal_column(name)
else:
- name = self.preparer.quote(column, name)
+ name = self.preparer.quote(name, column.quote)
if column.table is None or not column.table.named_with_column:
return name
else:
if getattr(column.table, 'schema', None):
- schema_prefix = self.preparer.quote(column.table, column.table.schema) + '.'
+ schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.'
else:
schema_prefix = ''
- return schema_prefix + self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name
+ return schema_prefix + self.preparer.quote(ANONYMOUS_LABEL.sub(self._process_anon, column.table.name), column.table.quote) + "." + name
def escape_literal_column(self, text):
"""provide escaping for the literal_column() construct."""
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam and (existing.unique or bindparam.unique):
- raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
+ raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
return truncname
def _process_anon(self, match):
- (ident, derived) = match.group(1,2)
+ (ident, derived) = match.group(1, 2)
key = ('anonymous', ident)
if key in self.generated_ids:
def bindparam_string(self, name):
if self.positional:
self.positiontup.append(name)
-
- return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
+ return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
+ else:
+ return self.bindtemplate % {'name':name}
def visit_alias(self, alias, asfrom=False, **kwargs):
if asfrom:
froms = select._get_display_froms(existingfroms)
- correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms])))
+ correlate_froms = util.Set(sql._from_objects(*froms))
# TODO: might want to propagate existing froms for select(select(select))
# where innermost select should correlate to outermost
[c for c in [
self.process(
self.label_select_column(select, co, asfrom=asfrom),
+ render_labels=True,
**column_clause_args)
for co in select.inner_columns
]
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
if getattr(table, "schema", None):
- return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name)
+ return self.preparer.quote(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
else:
- return self.preparer.quote(table, table.name)
+ return self.preparer.quote(table.name, table.quote)
else:
return ""
return (insert + " INTO %s (%s) VALUES (%s)" %
(preparer.format_table(insert_stmt.table),
- ', '.join([preparer.quote(c[0], c[0].name)
+ ', '.join([preparer.quote(c[0].name, c[0].quote)
for c in colparams]),
', '.join([c[1] for c in colparams])))
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0].name, c[0].quote), c[1]) for c in colparams], ', ')
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
if constraint.name is not None:
self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint]))
+ self.append("(%s)" % ', '.join([self.preparer.quote(c.name, c.quote) for c in constraint]))
self.define_constraint_deferrability(constraint)
def visit_foreign_key_constraint(self, constraint):
preparer.format_constraint(constraint))
table = list(constraint.elements)[0].column.table
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]),
+ ', '.join([preparer.quote(f.parent.name, f.parent.quote) for f in constraint.elements]),
preparer.format_table(table),
- ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements])
+ ', '.join([preparer.quote(f.column.name, f.column.quote) for f in constraint.elements])
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
if constraint.name is not None:
self.append("CONSTRAINT %s " %
self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint])))
+ self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c.name, c.quote) for c in constraint])))
self.define_constraint_deferrability(constraint)
def define_constraint_deferrability(self, constraint):
self.append("INDEX %s ON %s (%s)" \
% (preparer.format_index(index),
preparer.format_table(index.table),
- string.join([preparer.quote(c, c.name) for c in index.columns], ', ')))
+ string.join([preparer.quote(c.name, c.quote) for c in index.columns], ', ')))
self.execute()
or not self.legal_characters.match(unicode(value))
or (lc_value != value))
- def quote(self, obj, ident):
- if getattr(obj, 'quote', False):
+ def quote(self, ident, force):
+ if force:
return self.quote_identifier(ident)
+ elif force is False:
+ return ident
+
if ident in self.__strings:
return self.__strings[ident]
else:
self.__strings[ident] = ident
return self.__strings[ident]
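# A minimal sketch (illustration only) of the new quote()
# contract: the identifier comes first, then a 'force' flag,
# typically the .quote attribute of the schema item.  True always
# quotes, False never quotes, and None defers to the preparer's
# reserved-word/legal-character rules.
from sqlalchemy.sql.compiler import IdentifierPreparer
from sqlalchemy.engine.default import DefaultDialect

preparer = IdentifierPreparer(DefaultDialect())
preparer.quote('order', True)    # forced: always rendered quoted
preparer.quote('order', False)   # suppressed: rendered as-is
preparer.quote('order', None)    # quoted only when rules require it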
- def should_quote(self, object):
- return object.quote or self._requires_quotes(object.name)
-
def format_sequence(self, sequence, use_schema=True):
- name = self.quote(sequence, sequence.name)
+ name = self.quote(sequence.name, sequence.quote)
if not self.omit_schema and use_schema and sequence.schema is not None:
- name = self.quote(sequence, sequence.schema) + "." + name
+ name = self.quote(sequence.schema, sequence.quote) + "." + name
return name
def format_label(self, label, name=None):
- return self.quote(label, name or label.name)
+ return self.quote(name or label.name, label.quote)
def format_alias(self, alias, name=None):
- return self.quote(alias, name or alias.name)
+ return self.quote(name or alias.name, alias.quote)
def format_savepoint(self, savepoint, name=None):
- return self.quote(savepoint, name or savepoint.ident)
+ return self.quote(name or savepoint.ident, savepoint.quote)
def format_constraint(self, constraint):
- return self.quote(constraint, constraint.name)
+ return self.quote(constraint.name, constraint.quote)
def format_index(self, index):
- return self.quote(index, index.name)
+ return self.quote(index.name, index.quote)
def format_table(self, table, use_schema=True, name=None):
"""Prepare a quoted table and schema name."""
if name is None:
name = table.name
- result = self.quote(table, name)
+ result = self.quote(name, table.quote)
if not self.omit_schema and use_schema and getattr(table, "schema", None):
- result = self.quote(table, table.schema) + "." + result
+ result = self.quote(table.schema, table.quote_schema) + "." + result
return result
def format_column(self, column, use_table=False, name=None, table_name=None):
- """Prepare a quoted column name.
-
- deprecated. use preparer.quote(col, column.name) or combine with format_table()
- """
+ """Prepare a quoted column name."""
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name)
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote)
else:
- return self.quote(column, name)
+ return self.quote(name, column.quote)
else:
# literal textual elements get stuck into ColumnClause a lot, which shouldn't get quoted
if use_table:
# a longer sequence.
if not self.omit_schema and use_schema and getattr(table, 'schema', None):
- return (self.quote_identifier(table.schema),
+ return (self.quote(table.schema, table.quote_schema),
self.format_table(table, use_schema=False))
else:
return (self.format_table(table, use_schema=False), )
"""
import itertools, re
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
functions, schema, sql_util = None, None, None
-DefaultDialect, ClauseAdapter = None, None
+DefaultDialect, ClauseAdapter, Annotated = None, None, None
__all__ = [
'Alias', 'ClauseElement',
def exists(*args, **kwargs):
"""Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object.
+
+ Calling styles are of the following forms::
+
+ # use on an existing select()
+ s = select([<columns>]).where(<criterion>)
+ s = exists(s)
+
+ # construct a select() at once
+ exists(['*'], **select_arguments).where(<criterion>)
+
+ # columns argument is optional, generates "EXISTS (SELECT *)"
+ # by default.
+ exists().where(<criterion>)
- The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by
- itself or used as a subquery within an enclosing select.
-
- \*args, \**kwargs
- all arguments are sent directly to the [sqlalchemy.sql.expression#select()]
- function to produce a ``SELECT`` statement.
"""
-
return _Exists(*args, **kwargs)
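# A concrete sketch of the calling styles documented above; the
# 'users' table is hypothetical and defined only for illustration.
from sqlalchemy import Table, Column, Integer, MetaData, select, exists

users = Table('users', MetaData(), Column('id', Integer))

# wrap an existing select()
stmt = exists(select([users.c.id]).where(users.c.id == 5))

# no columns given: generates "EXISTS (SELECT * ...)"
stmt = exists().where(users.c.id == 5)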
def union(*selects, **kwargs):
return CompoundSelect(keyword, *selects, **kwargs)
def _is_literal(element):
- return not isinstance(element, ClauseElement)
+ return not isinstance(element, (ClauseElement, Operators))
+
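+# flatten the _get_from_objects() results of any number of elements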
+def _from_objects(*elements, **kwargs):
+ return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
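+
+# apply an anonymous label to column expressions lacking a .name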
+def _labeled(element):
+ if not hasattr(element, 'name'):
+ return element.label(None)
+ else:
+ return element
+
def _literal_as_text(element):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
return _TextClause(unicode(element))
else:
return element
def _literal_as_column(element):
- if isinstance(element, Operators):
- return element.clause_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
return literal_column(str(element))
else:
return element
def _literal_as_binds(element, name=None, type_=None):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
if element is None:
return null()
return element
def _no_literals(element):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
- raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+ raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
else:
return element
def _corresponding_column_or_error(fromclause, column, require_embedded=False):
c = fromclause.corresponding_column(column, require_embedded=require_embedded)
if not c:
- raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+ raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
return c
def _selectable(element):
elif isinstance(element, Selectable):
return element
else:
- raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
+ raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
-
def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
return isinstance(col, ColumnElement)
class ClauseElement(object):
"""Base class for elements of a programmatically constructed SQL expression."""
__metaclass__ = _FigureVisitName
-
+ _annotations = {}
+ supports_execution = False
+
def _clone(self):
"""Create a shallow copy of this ClauseElement.
"""
raise NotImplementedError(repr(self))
+
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with the given annotations dictionary."""
+
+ global Annotated
+ if Annotated is None:
+ from sqlalchemy.sql.util import Annotated
+ return Annotated(self, values)
def unique_params(self, *optionaldict, **kwargs):
"""Return a copy with ``bindparam()`` elments replaced.
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
- raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument")
+ raise exc.ArgumentError("params() takes zero or one positional dictionary argument")
def visit_bindparam(bind):
if bind.key in kwargs:
bind.value = kwargs[bind.key]
if unique:
bind._convert_to_unique()
- return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True)
+ return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam})
def compare(self, other):
"""Compare this ClauseElement to the given ClauseElement.
def self_group(self, against=None):
return self
- def supports_execution(self):
- """Return True if this clause element represents a complete executable statement."""
-
- return False
-
def bind(self):
"""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
return self._bind
except AttributeError:
pass
- for f in self._get_from_objects():
+ for f in _from_objects(self):
if f is self:
continue
engine = f.bind
'Engine for execution. Or, assign a bind to the statement '
'or the Metadata of its underlying tables to enable '
'implicit execution via this method.' % label)
- raise exceptions.UnboundExecutionError(msg)
+ raise exc.UnboundExecutionError(msg)
return e.execute_clauseelement(self, multiparams, params)
def scalar(self, *multiparams, **params):
self.__module__, self.__class__.__name__, id(self), friendly)
+class _Immutable(object):
+ """mark a ClauseElement as 'immutable' when expressions are cloned."""
+
+ def _clone(self):
+ return self
+
class Operators(object):
def __and__(self, other):
return self.operate(operators.and_, other)
return self.operate(operators.op, opstring, b)
return op
- def clause_element(self):
- raise NotImplementedError()
-
def operate(self, op, *other, **kwargs):
raise NotImplementedError()
def ilike(self, other, escape=None):
return self.operate(operators.ilike_op, other, escape=escape)
- def in_(self, *other):
+ def in_(self, other):
return self.operate(operators.in_op, other)
def startswith(self, other, **kwargs):
def __compare(self, op, obj, negate=None, reverse=False, **kwargs):
if obj is None or isinstance(obj, _Null):
if op == operators.eq:
- return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot)
+ return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot)
elif op == operators.ne:
- return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_)
+ return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_)
else:
- raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+ raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL")
else:
obj = self._check_literal(obj)
if reverse:
- return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+ return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
else:
- return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+ return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
def __operate(self, op, obj, reverse=False):
obj = self._check_literal(obj)
type_ = self._compare_type(obj)
if reverse:
- return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_)
+ return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_)
else:
- return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_)
+ return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_)
# a mapping of operators with the method they use, along with their negated
# operator for comparison operators
o = _CompareMixin.operators[op]
return o[0](self, op, other, reverse=True, *o[1:], **kwargs)
- def in_(self, *other):
- return self._in_impl(operators.in_op, operators.notin_op, *other)
-
- def _in_impl(self, op, negate_op, *other):
- # Handle old style *args argument passing
- if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)):
- util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable')
- seq_or_selectable = other
- else:
- seq_or_selectable = other[0]
+ def in_(self, other):
+ return self._in_impl(operators.in_op, operators.notin_op, other)
+ def _in_impl(self, op, negate_op, seq_or_selectable):
if isinstance(seq_or_selectable, Selectable):
return self.__compare( op, seq_or_selectable, negate=negate_op)
for o in seq_or_selectable:
if not _is_literal(o):
if not isinstance( o, _CompareMixin):
- raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
+ raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
else:
o = self._bind_param(o)
args.append(o)
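# A minimal sketch of the new single-argument in_(): the deprecated
# varargs form col.in_(1, 2, 3) is removed; pass one sequence or
# one selectable.  The 'users' table is hypothetical.
from sqlalchemy import Table, Column, Integer, MetaData, select

users = Table('users', MetaData(), Column('id', Integer))

users.c.id.in_([1, 2, 3])             # a sequence of values
users.c.id.in_(select([users.c.id]))  # a selectable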
if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
other.type = self.type
return other
- elif isinstance(other, Operators):
- return other.expression_element()
+ elif hasattr(other, '__clause_element__'):
+ return other.__clause_element__()
elif _is_literal(other):
return self._bind_param(other)
else:
return other
- def clause_element(self):
- """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``."""
- return self
-
- def expression_element(self):
- """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
-
- return self
-
def _compare_type(self, obj):
"""Allow subclasses to override the type used in constructing
``_BinaryExpression`` objects.
primary_key = False
foreign_keys = []
-
+ quote = None
+
def base_columns(self):
- if hasattr(self, '_base_columns'):
- return self._base_columns
- self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
+ if not hasattr(self, '_base_columns'):
+ self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
return self._base_columns
base_columns = property(base_columns)
def proxy_set(self):
- if hasattr(self, '_proxy_set'):
- return self._proxy_set
- s = util.Set([self])
- if hasattr(self, 'proxies'):
- for c in self.proxies:
- s = s.union(c.proxy_set)
- self._proxy_set = s
- return s
+ if not hasattr(self, '_proxy_set'):
+ s = util.Set([self])
+ if hasattr(self, 'proxies'):
+ for c in self.proxies:
+ s.update(c.proxy_set)
+ self._proxy_set = s
+ return self._proxy_set
proxy_set = property(proxy_set)
def shares_lineage(self, othercolumn):
co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None))
co.proxies = [self]
- selectable.columns[name]= co
+ selectable.columns[name] = co
return co
def anon_label(self):
def __contains__(self, other):
if not isinstance(other, basestring):
- raise exceptions.ArgumentError("__contains__ requires a string argument")
+ raise exc.ArgumentError("__contains__ requires a string argument")
return util.OrderedProperties.__contains__(self, other)
def contains_column(self, col):
l.append(c==local)
return and_(*l)
+ def __hash__(self):
+ return hash(tuple(self._list))
+
class Selectable(ClauseElement):
"""mark a class as being selectable"""
"""Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
__visit_name__ = 'fromclause'
- named_with_column=False
+ named_with_column = False
_hide_froms = []
+ quote = None
def _get_from_objects(self, **modifiers):
return []
return fromclause in util.Set(self._cloned_set)
def replace_selectable(self, old, alias):
- """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
+ """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
- global ClauseAdapter
- if ClauseAdapter is None:
- from sqlalchemy.sql.util import ClauseAdapter
- return ClauseAdapter(alias).traverse(self, clone=True)
+ global ClauseAdapter
+ if ClauseAdapter is None:
+ from sqlalchemy.sql.util import ClauseAdapter
+ return ClauseAdapter(alias).traverse(self)
def correspond_on_equivalents(self, column, equivalents):
col = self.corresponding_column(column, require_embedded=True)
def _convert_to_unique(self):
if not self.unique:
- self.unique=True
+ self.unique = True
self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param')
def _get_from_objects(self, **modifiers):
__visit_name__ = 'textclause'
_bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+ supports_execution = True
_hide_froms = []
oid_column = None
def _get_from_objects(self, **modifiers):
return []
- def supports_execution(self):
- return True
-
- def _table_iterator(self):
- return iter([])
-
class _Null(ColumnElement):
"""Represent the NULL keyword in a SQL statement.
__visit_name__ = 'calculatedclause'
def __init__(self, name, *clauses, **kwargs):
+ ColumnElement.__init__(self)
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type_', None))
self._bind = kwargs.get('bind', None)
def clauses(self):
if isinstance(self.clause_expr, _Grouping):
- return self.clause_expr.elem
+ return self.clause_expr.element
else:
return self.clause_expr
clauses = property(clauses)
__visit_name__ = _UnaryExpression.__visit_name__
def __init__(self, *args, **kwargs):
- kwargs['correlate'] = True
- s = select(*args, **kwargs).as_scalar().self_group()
+ if args and isinstance(args[0], _SelectBaseMixin):
+ s = args[0]
+ else:
+ if not args:
+ args = ([literal_column('*')],)
+ s = select(*args, **kwargs).as_scalar().self_group()
+
_UnaryExpression.__init__(self, s, operator=operators.exists)
def select(self, whereclause=None, **params):
self.right = _selectable(right).self_group()
if onclause is None:
- self.onclause = self.__match_primaries(self.left, self.right)
+ self.onclause = self._match_primaries(self.left, self.right)
else:
self.onclause = onclause
def get_children(self, **kwargs):
return self.left, self.right, self.onclause
- def __match_primaries(self, primary, secondary):
+ def _match_primaries(self, primary, secondary):
global sql_util
if not sql_util:
from sqlalchemy.sql import util as sql_util
return self.select(use_labels=True, correlate=False).alias(name)
def _hide_froms(self):
- return itertools.chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set])
+ return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set])
_hide_froms = property(_hide_froms)
def _get_from_objects(self, **modifiers):
def __init__(self, selectable, alias=None):
baseselectable = selectable
while isinstance(baseselectable, Alias):
- baseselectable = baseselectable.selectable
+ baseselectable = baseselectable.element
self.original = baseselectable
- self.selectable = selectable
+ self.supports_execution = baseselectable.supports_execution
+ self.element = selectable
if alias is None:
if self.original.named_with_column:
alias = getattr(self.original, 'name', None)
def is_derived_from(self, fromclause):
if fromclause in util.Set(self._cloned_set):
return True
- return self.selectable.is_derived_from(fromclause)
-
- def supports_execution(self):
- return self.original.supports_execution()
-
- def _table_iterator(self):
- return self.original._table_iterator()
+ return self.element.is_derived_from(fromclause)
def _populate_column_collection(self):
- for col in self.selectable.columns:
+ for col in self.element.columns:
col._make_proxy(self)
- if self.selectable.oid_column is not None:
- self._oid_column = self.selectable.oid_column._make_proxy(self)
+ if self.element.oid_column is not None:
+ self._oid_column = self.element.oid_column._make_proxy(self)
def _copy_internals(self, clone=_clone):
- self._reset_exported()
- self.selectable = _clone(self.selectable)
- baseselectable = self.selectable
- while isinstance(baseselectable, Alias):
- baseselectable = baseselectable.selectable
- self.original = baseselectable
+ self._reset_exported()
+ self.element = _clone(self.element)
+ baseselectable = self.element
+ while isinstance(baseselectable, Alias):
+ baseselectable = baseselectable.selectable
+ self.original = baseselectable
def get_children(self, column_collections=True, aliased_selectables=True, **kwargs):
if column_collections:
for c in self.c:
yield c
if aliased_selectables:
- yield self.selectable
+ yield self.element
def _get_from_objects(self, **modifiers):
return [self]
def bind(self):
- return self.selectable.bind
+ return self.element.bind
bind = property(bind)
-class _ColumnElementAdapter(ColumnElement):
- """Adapts a ClauseElement which may or may not be a
- ColumnElement subclass itself into an object which
- acts like a ColumnElement.
- """
+class _Grouping(ColumnElement):
+ """Represent a grouping within a column expression"""
- def __init__(self, elem):
- self.elem = elem
- self.type = getattr(elem, 'type', None)
+ def __init__(self, element):
+ ColumnElement.__init__(self)
+ self.element = element
+ self.type = getattr(element, 'type', None)
def key(self):
- return self.elem.key
+ return self.element.key
key = property(key)
def _label(self):
try:
- return self.elem._label
+ return self.element._label
except AttributeError:
return self.anon_label
_label = property(_label)
def _copy_internals(self, clone=_clone):
- self.elem = clone(self.elem)
+ self.element = clone(self.element)
def get_children(self, **kwargs):
- return self.elem,
+ return self.element,
def _get_from_objects(self, **modifiers):
- return self.elem._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def __getattr__(self, attr):
- return getattr(self.elem, attr)
+ return getattr(self.element, attr)
def __getstate__(self):
- return {'elem':self.elem, 'type':self.type}
+ return {'element':self.element, 'type':self.type}
def __setstate__(self, state):
- self.elem = state['elem']
+ self.element = state['element']
self.type = state['type']
-class _Grouping(_ColumnElementAdapter):
- """Represent a grouping within a column expression"""
- pass
-
class _FromGrouping(FromClause):
"""Represent a grouping of a FROM clause"""
__visit_name__ = 'grouping'
- def __init__(self, elem):
- self.elem = elem
+ def __init__(self, element):
+ self.element = element
def columns(self):
- return self.elem.columns
+ return self.element.columns
columns = c = property(columns)
def _hide_froms(self):
- return self.elem._hide_froms
+ return self.element._hide_froms
_hide_froms = property(_hide_froms)
def get_children(self, **kwargs):
- return self.elem,
+ return self.element,
def _copy_internals(self, clone=_clone):
- self.elem = clone(self.elem)
+ self.element = clone(self.element)
def _get_from_objects(self, **modifiers):
- return self.elem._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def __getattr__(self, attr):
- return getattr(self.elem, attr)
+ return getattr(self.element, attr)
class _Label(ColumnElement):
"""Represents a column label (AS).
``ColumnElement`` subclasses.
"""
- def __init__(self, name, obj, type_=None):
- while isinstance(obj, _Label):
- obj = obj.obj
- self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
- self.obj = obj.self_group(against=operators.as_)
- self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
+ def __init__(self, name, element, type_=None):
+ while isinstance(element, _Label):
+ element = element.element
+ self.name = name or "{ANON %d %s}" % (id(self), getattr(element, 'name', 'anon'))
+ self.element = element.self_group(against=operators.as_)
+ self.type = sqltypes.to_instance(type_ or getattr(element, 'type', None))
def key(self):
return self.name
_label = property(_label)
def _proxy_attr(name):
+ get = util.attrgetter(name)
def attr(self):
- return getattr(self.obj, name)
+ return get(self.element)
return property(attr)
proxies = _proxy_attr('proxies')
primary_key = _proxy_attr('primary_key')
foreign_keys = _proxy_attr('foreign_keys')
- def expression_element(self):
- return self.obj
-
def get_children(self, **kwargs):
- return self.obj,
+ return self.element,
def _copy_internals(self, clone=_clone):
- self.obj = clone(self.obj)
+ self.element = clone(self.element)
def _get_from_objects(self, **modifiers):
- return self.obj._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def _make_proxy(self, selectable, name = None):
- if isinstance(self.obj, (Selectable, ColumnElement)):
- e = self.obj._make_proxy(selectable, name=self.name)
+ if isinstance(self.element, (Selectable, ColumnElement)):
+ e = self.element._make_proxy(selectable, name=self.name)
else:
e = column(self.name)._make_proxy(selectable=selectable)
e.proxies.append(self)
return e
-class _ColumnClause(ColumnElement):
+class _ColumnClause(_Immutable, ColumnElement):
"""Represents a generic column expression from any textual string.
This includes columns associated with tables, aliases and select
return self.name.encode('ascii', 'backslashreplace')
description = property(description)
- def _clone(self):
- # ColumnClause is immutable
- return self
-
def _label(self):
- """Generate a 'label' string for this column.
- """
-
- # for a "literal" column, we've no idea what the text is
- # therefore no 'label' can be automatically generated
if self.is_literal:
return None
if not self.__label:
counter = 1
while label in self.table.c:
label = self.__label + "_" + str(counter)
- counter +=1
+ counter += 1
self.__label = label
else:
self.__label = self.name
return self.__label
-
_label = property(_label)
def label(self, name):
- # if going off the "__label" property and its None, we have
- # no label; return self
if name is None:
return self
else:
return super(_ColumnClause, self).label(name)
def _get_from_objects(self, **modifiers):
- if self.table is not None:
+ if self.table:
return [self.table]
else:
return []
def _bind_param(self, obj):
return _BindParamClause(self.name, obj, type_=self.type, unique=True)
- def _make_proxy(self, selectable, name = None):
+ def _make_proxy(self, selectable, name=None, attach=True):
# propagate the "is_literal" flag only if we are keeping our name,
# otherwise it's considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
c.proxies = [self]
- if not self._is_oid:
+ if attach and not self._is_oid:
selectable.columns[c.name] = c
return c
def _compare_type(self, obj):
return self.type
-class TableClause(FromClause):
+class TableClause(_Immutable, FromClause):
"""Represents a "table" construct.
Note that this represents tables only as another syntactical
return self.name.encode('ascii', 'backslashreplace')
description = property(description)
- def _clone(self):
- # TableClause is immutable
- return self
-
def append_column(self, c):
self._columns[c.name] = c
c.table = self
def _get_from_objects(self, **modifiers):
return [self]
-
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
+ supports_execution = True
+
def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, autocommit=False):
self.use_labels = use_labels
self.for_update = for_update
"""
return self.as_scalar().label(name)
- def supports_execution(self):
- """part of the ClauseElement contract; returns ``True`` in all cases for this class."""
-
- return True
-
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to True."""
class _ScalarSelect(_Grouping):
__visit_name__ = 'grouping'
- def __init__(self, elem):
- self.elem = elem
- cols = list(elem.inner_columns)
+ def __init__(self, element):
+ self.element = element
+ cols = list(element.inner_columns)
if len(cols) != 1:
- raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
+ raise exc.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
self.type = cols[0].type
def columns(self):
- raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
+ raise exc.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
columns = c = property(columns)
def self_group(self, **kwargs):
if not numcols:
numcols = len(s.c)
elif len(s.c) != numcols:
- raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
+ raise exc.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
(1, len(self.selects[0].c), n+1, len(s.c))
)
if s._order_by_clause:
return (column_collections and list(self.c) or []) + \
[self._order_by_clause, self._group_by_clause] + list(self.selects)
- def _table_iterator(self):
- for s in self.selects:
- for t in s._table_iterator():
- yield t
-
def bind(self):
if self._bind:
return self._bind
self._distinct = distinct
self._correlate = util.Set()
+ self._froms = util.OrderedSet()
if columns:
self._raw_columns = [
for c in
[_literal_as_column(c) for c in columns]
]
+
+ self._froms.update(_from_objects(*self._raw_columns))
else:
self._raw_columns = []
-
- if from_obj:
- self._froms = util.Set([
- _is_literal(f) and _TextClause(f) or f
- for f in util.to_list(from_obj)
- ])
- else:
- self._froms = util.Set()
-
+
if whereclause:
self._whereclause = _literal_as_text(whereclause)
+ self._froms.update(_from_objects(self._whereclause, is_where=True))
else:
self._whereclause = None
+ if from_obj:
+ self._froms.update([
+ _is_literal(f) and _TextClause(f) or f
+ for f in util.to_list(from_obj)
+ ])
+
if having:
self._having = _literal_as_text(having)
else:
correlating.
"""
- froms = util.OrderedSet()
-
- for col in self._raw_columns:
- froms.update(col._get_from_objects())
-
- if self._whereclause is not None:
- froms.update(self._whereclause._get_from_objects(is_where=True))
-
- if self._froms:
- froms.update(self._froms)
+ froms = self._froms
toremove = itertools.chain(*[f._hide_froms for f in froms])
- froms.difference_update(toremove)
+ if toremove:
+ froms = froms.difference(toremove)
if len(froms) > 1 or self._correlate:
if self._correlate:
- froms.difference_update(_cloned_intersection(froms, self._correlate))
+ froms = froms.difference(_cloned_intersection(froms, self._correlate))
if self._should_correlate and existing_froms:
- froms.difference_update(_cloned_intersection(froms, existing_froms))
+ froms = froms.difference(_cloned_intersection(froms, existing_froms))
if not len(froms):
- raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self)
+ raise exc.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self)
return froms
froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
def type(self):
- raise exceptions.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.")
+ raise exc.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.")
type = property(type)
def locate_all_froms(self):
is specifically for those FromClause elements that would actually be rendered.
"""
- if hasattr(self, '_all_froms'):
- return self._all_froms
-
- froms = util.Set(
- itertools.chain(*
- [self._froms] +
- [f._get_from_objects() for f in self._froms] +
- [col._get_from_objects() for col in self._raw_columns]
- )
- )
+ if not hasattr(self, '_all_froms'):
+ self._all_froms = self._froms.union(_from_objects(*list(self._froms)))
- if self._whereclause:
- froms.update(self._whereclause._get_from_objects(is_where=True))
-
- self._all_froms = froms
- return froms
+ return self._all_froms
def inner_columns(self):
"""an iteratorof all ColumnElement expressions which would
def is_derived_from(self, fromclause):
if self in util.Set(fromclause._cloned_set):
return True
-
+
for f in self.locate_all_froms():
if f.is_derived_from(fromclause):
return True
"""return child elements as per the ClauseElement specification."""
return (column_collections and list(self.columns) or []) + \
- list(self.locate_all_froms()) + \
+ list(self._froms) + \
[x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
def column(self, column):
column = column.self_group(against=operators.comma_op)
s._raw_columns = s._raw_columns + [column]
+ s._froms = s._froms.union(_from_objects(column))
return s
def where(self, whereclause):
"""
s = self._generate()
- s._should_correlate=False
+ s._should_correlate = False
if fromclauses == (None,):
s._correlate = util.Set()
else:
def append_correlation(self, fromclause):
"""append the given correlation expression to this select() construct."""
- self._should_correlate=False
+ self._should_correlate = False
self._correlate = self._correlate.union([fromclause])
def append_column(self, column):
column = column.self_group(against=operators.comma_op)
self._raw_columns = self._raw_columns + [column]
+ self._froms = self._froms.union(_from_objects(column))
self._reset_exported()
def append_prefix(self, clause):
The expression will be joined to existing WHERE criterion via AND.
"""
+ whereclause = _literal_as_text(whereclause)
+ self._froms = self._froms.union(_from_objects(whereclause, is_where=True))
+
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
+ self._whereclause = and_(self._whereclause, whereclause)
else:
- self._whereclause = _literal_as_text(whereclause)
+ self._whereclause = whereclause
def append_having(self, having):
"""append the given expression to this select() construct's HAVING criterion.
return intersect_all(self, other, **kwargs)
- def _table_iterator(self):
- for t in visitors.NoColumnVisitor().iterate(self):
- if isinstance(t, TableClause):
- yield t
-
def bind(self):
if self._bind:
return self._bind
- for f in self._froms:
- if f is self:
- continue
- e = f.bind
- if e:
- self._bind = e
- return e
- # look through the columns (largely synomous with looking
- # through the FROMs except in the case of _CalculatedClause/_Function)
- for c in self._raw_columns:
- if getattr(c, 'table', None) is self:
- continue
- e = c.bind
+ if not self._froms:
+ for c in self._raw_columns:
+ e = c.bind
+ if e:
+ self._bind = e
+ return e
+ else:
+ e = list(self._froms)[0].bind
if e:
self._bind = e
return e
+
return None
+
def _set_bind(self, bind):
self._bind = bind
bind = property(bind, _set_bind)
class _UpdateBase(ClauseElement):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
- def supports_execution(self):
- return True
-
- def _table_iterator(self):
- return iter([self.table])
+ supports_execution = True
def _generate(self):
s = self.__class__.__new__(self.__class__)
self._bind = bind
self.table = table
self.select = None
- self.inline=inline
+ self.inline = inline
if prefixes:
self._prefixes = [_literal_as_text(p) for p in prefixes]
else:
self._whereclause = clone(self._whereclause)
class _IdentifiedClause(ClauseElement):
+ supports_execution = True
+ quote = None
+
def __init__(self, ident):
self.ident = ident
- def supports_execution(self):
- return True
class SavepointClause(_IdentifiedClause):
pass
return a.between(b, c)
def in_op(a, b):
- return a.in_(*b)
+ return a.in_(b)
def notin_op(a, b):
raise NotImplementedError()
-from sqlalchemy import exceptions, schema, topological, util, sql
+from sqlalchemy import exc, schema, topological, util, sql
from sqlalchemy.sql import expression, operators, visitors
from itertools import chain
"""sort a collection of Table objects in order of their foreign-key dependency."""
tuples = []
- class TVisitor(schema.SchemaVisitor):
- def visit_foreign_key(_self, fkey):
- if fkey.use_alter:
- return
- parent_table = fkey.column.table
- if parent_table in tables:
- child_table = fkey.parent.table
- tuples.append( ( parent_table, child_table ) )
- vis = TVisitor()
+ def visit_foreign_key(fkey):
+ if fkey.use_alter:
+ return
+ parent_table = fkey.column.table
+ if parent_table in tables:
+ child_table = fkey.parent.table
+ tuples.append((parent_table, child_table))
+
for table in tables:
- vis.traverse(table)
+ visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key})
sequence = topological.sort(tuples, tables)
if reverse:
return util.reversed(sequence)
else:
return sequence
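# A minimal sketch of the dictionary-based visitor API used above:
# visitors.traverse(obj, opts, {<visit_name>: fn}) replaces the
# old visit_xxx=fn keyword style.  The table is hypothetical.
from sqlalchemy import Table, Column, Integer, ForeignKey, MetaData
from sqlalchemy.sql import visitors

t = Table('t', MetaData(), Column('id', Integer, ForeignKey('u.id')))

found = []
visitors.traverse(t, {'schema_visitor': True},
                  {'foreign_key': found.append})
# 'found' now holds the ForeignKey attached to t.c.id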
-def find_tables(clause, check_columns=False, include_aliases=False):
+def search(clause, target):
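+ """return True if 'target' is present (by identity) within 'clause'."""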
+ if not clause:
+ return False
+ for elem in visitors.iterate(clause, {'column_collections':False}):
+ if elem is target:
+ return True
+ else:
+ return False
+
+def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False):
"""locate Table objects within the given expression."""
tables = []
- kwargs = {}
+ _visitors = {}
+
+ def visit_something(elem):
+ tables.append(elem)
+
+ if include_selects:
+ _visitors['select'] = _visitors['compound_select'] = visit_something
+
+ if include_joins:
+ _visitors['join'] = visit_something
+
if include_aliases:
- def visit_alias(alias):
- tables.append(alias)
- kwargs['visit_alias'] = visit_alias
+ _visitors['alias'] = visit_something
if check_columns:
def visit_column(column):
tables.append(column.table)
- kwargs['visit_column'] = visit_column
+ _visitors['column'] = visit_column
- def visit_table(table):
- tables.append(table)
- kwargs['visit_table'] = visit_table
+ _visitors['table'] = visit_something
- visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs)
+ visitors.traverse(clause, {'column_collections':False}, _visitors)
return tables
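# Hypothetical usage of the extended find_tables(): the new
# include_joins/include_selects flags collect Join and Select
# constructs alongside the tables themselves.
from sqlalchemy import Table, Column, Integer, MetaData
from sqlalchemy.sql.util import find_tables

m = MetaData()
a = Table('a', m, Column('id', Integer))
b = Table('b', m, Column('id', Integer))
j = a.join(b, a.c.id == b.c.id)

find_tables(j)                      # [a, b], in traversal order
find_tables(j, include_joins=True)  # the Join itself plus a and b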
def find_columns(clause):
cols = util.Set()
def visit_column(col):
cols.add(col)
- visitors.traverse(clause, visit_column=visit_column)
+ visitors.traverse(clause, {}, {'column':visit_column})
return cols
def join_condition(a, b, ignore_nonexistent_tables=False):
for fk in b.foreign_keys:
try:
col = fk.get_referent(a)
- except exceptions.NoReferencedTableError:
+ except exc.NoReferencedTableError:
if ignore_nonexistent_tables:
continue
else:
if col:
crit.append(col == fk.parent)
constraints.add(fk.constraint)
-
if a is not b:
for fk in a.foreign_keys:
try:
col = fk.get_referent(b)
- except exceptions.NoReferencedTableError:
+ except exc.NoReferencedTableError:
if ignore_nonexistent_tables:
continue
else:
raise
-
+
if col:
crit.append(col == fk.parent)
constraints.add(fk.constraint)
if len(crit) == 0:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Can't find any foreign key relationships "
"between '%s' and '%s'" % (a.description, b.description))
elif len(constraints) > 1:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Can't determine join between '%s' and '%s'; "
"tables have more than one foreign key "
"constraint relationship between them. "
return (crit[0])
else:
return sql.and_(*crit)
+
+class Annotated(object):
+ """clones a ClauseElement and applies an 'annotations' dictionary.
+
+ Unlike regular clones, this clone also mimics __hash__() and
+ __cmp__() of the original element so that it takes its place
+ in hashed collections.
+
+ A reference to the original element is maintained so that its
+ hash value stays reserved: if the original element were garbage
+ collected, its hash value could be reused, causing conflicts.
+
+ """
+ def __new__(cls, *args):
+ if not args:
+ return object.__new__(cls)
+ else:
+ element, values = args
+ return object.__new__(
+ type.__new__(type, "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {})
+ )
+
+ def __init__(self, element, values):
+ self.__dict__ = element.__dict__.copy()
+ self.__element = element
+ self._annotations = values
+
+ def _annotate(self, values):
+ _values = self._annotations.copy()
+ _values.update(values)
+ clone = self.__class__.__new__(self.__class__)
+ clone.__dict__ = self.__dict__.copy()
+ clone._annotations = _values
+ return clone
+
+ def __hash__(self):
+ return hash(self.__element)
+
+ def __cmp__(self, other):
+ return cmp(hash(self.__element), hash(other))
+
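# A minimal sketch of the behavior described above; the table and
# the annotation key are hypothetical.
from sqlalchemy import Table, Column, Integer, MetaData

users = Table('users', MetaData(), Column('id', Integer))

annotated = users.c.id._annotate({'note': 'hypothetical'})
assert annotated._annotations['note'] == 'hypothetical'
# the copy hashes like the original, so it can take the original's
# place in dicts and sets
assert hash(annotated) == hash(users.c.id)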
+def splice_joins(left, right, stop_on=None):
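+ """return 'right' spliced onto 'left', adapting join criteria.
+
+ Descends through nested Joins on the right side, cloning each
+ Join and adapting its onclause against 'left' via ClauseAdapter,
+ stopping at 'stop_on'.
+ """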
+ if left is None:
+ return right
+
+ stack = [(right, None)]
+
+ adapter = ClauseAdapter(left)
+ ret = None
+ while stack:
+ (right, prevright) = stack.pop()
+ if isinstance(right, expression.Join) and right is not stop_on:
+ right = right._clone()
+ right._reset_exported()
+ right.onclause = adapter.traverse(right.onclause)
+ stack.append((right.left, right))
+ else:
+ right = adapter.traverse(right)
+ if prevright:
+ prevright.left = right
+ if not ret:
+ ret = right
+
+ return ret
def reduce_columns(columns, *clauses):
"""given a list of columns, return a 'reduced' set based on natural equivalents.
omit.add(c)
break
for clause in clauses:
- visitors.traverse(clause, visit_binary=visit_binary)
+ visitors.traverse(clause, {}, {'binary':visit_binary})
return expression.ColumnSet(columns.difference(omit))
"""traverse an expression and locate binary criterion pairs."""
if consider_as_foreign_keys and consider_as_referenced_keys:
- raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
+ raise exc.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
def visit_binary(binary):
if not any_operator and binary.operator != operators.eq:
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
pairs = []
- visitors.traverse(expression, visit_binary=visit_binary)
+ visitors.traverse(expression, {}, {'binary':visit_binary})
return pairs
def folded_equivalents(join, equivs=None):
This function is used by Join.select(fold_equivalents=True).
TODO: deprecate ?
- """
+ """
if equivs is None:
equivs = util.Set()
def visit_binary(binary):
if binary.operator == operators.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
equivs.add(binary.left)
- visitors.traverse(join.onclause, visit_binary=visit_binary)
+ visitors.traverse(join.onclause, {}, {'binary':visit_binary})
collist = []
if isinstance(join.left, expression.Join):
left = folded_equivalents(join.left, equivs)
def keys(self):
return self.row.keys()
-def row_adapter(from_, equivalent_columns=None):
- """create a row adapter callable against a selectable."""
-
- if equivalent_columns is None:
- equivalent_columns = {}
-
- def locate_col(col):
- c = from_.corresponding_column(col)
- if c:
- return c
- elif col in equivalent_columns:
- for c2 in equivalent_columns[col]:
- corr = from_.corresponding_column(c2)
- if corr:
- return corr
- return col
-
- map = util.PopulateDict(locate_col)
-
- def adapt(row):
- return AliasedRow(row, map)
- return adapt
-
-class ColumnsInClause(visitors.ClauseVisitor):
- """Given a selectable, visit clauses and determine if any columns
- from the clause are in the selectable.
- """
-
- def __init__(self, selectable):
- self.selectable = selectable
- self.result = False
-
- def visit_column(self, column):
- if self.selectable.c.get(column.key) is column:
- self.result = True
-class ClauseAdapter(visitors.ClauseVisitor):
+class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""Given a clause (like as in a WHERE criterion), locate columns
which are embedded within a given selectable, and changes those
columns to be that of the selectable.
condition to read::
s.c.col1 == table2.c.col1
- """
-
- __traverse_options__ = {'column_collections':False}
- def __init__(self, selectable, include=None, exclude=None, equivalents=None):
- self.__traverse_options__ = self.__traverse_options__.copy()
- self.__traverse_options__['stop_on'] = [selectable]
+ """
+ def __init__(self, selectable, equivalents=None, include=None, exclude=None):
+ self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]}
self.selectable = selectable
self.include = include
self.exclude = exclude
- self.equivalents = equivalents
-
- def traverse(self, obj, clone=True):
- if not clone:
- raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True")
- return visitors.ClauseVisitor.traverse(self, obj, clone=True)
-
- def copy_and_chain(self, adapter):
- """create a copy of this adapter and chain to the given adapter.
-
- currently this adapter must be unchained to start, raises
- an exception if it's already chained.
-
- Does not modify the given adapter.
- """
+ self.equivalents = equivalents or {}
- if adapter is None:
- return self
+ def _corresponding_column(self, col, require_embedded):
+ newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)
- if hasattr(self, '_next'):
- raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)")
-
- ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents)
- ca._next = adapter
- return ca
+ if not newcol and col in self.equivalents:
+ for equiv in self.equivalents[col]:
+ newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded)
+ if newcol:
+ return newcol
+ return newcol
- def before_clone(self, col):
+ def replace(self, col):
if isinstance(col, expression.FromClause):
if self.selectable.is_derived_from(col):
return self.selectable
+
if not isinstance(col, expression.ColumnElement):
return None
- if self.include is not None:
- if col not in self.include:
- return None
- if self.exclude is not None:
- if col in self.exclude:
- return None
- newcol = self.selectable.corresponding_column(col, require_embedded=True)
- if newcol is None and self.equivalents is not None and col in self.equivalents:
- for equiv in self.equivalents[col]:
- newcol = self.selectable.corresponding_column(equiv, require_embedded=True)
- if newcol:
- return newcol
- return newcol
+
+ if self.include and col not in self.include:
+ return None
+ elif self.exclude and col in self.exclude:
+ return None
+
+ return self._corresponding_column(col, True)
+
+class ColumnAdapter(ClauseAdapter):
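+ """Extends ClauseAdapter with a cached ``columns`` lookup of adapted
+ columns, row adaptation via adapted_row(), and the ability to wrap
+ one adapter around another via wrap()."""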
+
+ def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None):
+ ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
+ if chain_to:
+ self.chain(chain_to)
+ self.columns = util.PopulateDict(self._locate_col)
+
+ def wrap(self, adapter):
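+ # produce a copy of this adapter whose lookups are post-processed
+ # by the given adapter: this adapter's translation runs first,
+ # then the given adapter's.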
+ ac = self.__class__.__new__(self.__class__)
+ ac.__dict__ = self.__dict__.copy()
+ ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
+ ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
+ ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
+ ac.columns = util.PopulateDict(ac._locate_col)
+ return ac
+
+ adapt_clause = ClauseAdapter.traverse
+ adapt_list = ClauseAdapter.copy_and_process
+
+ def _wrap(self, local, wrapped):
+ def locate(col):
+ col = local(col)
+ return wrapped(col)
+ return locate
+
+ def _locate_col(self, col):
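+ # find the corresponding column in the selectable; if none is
+ # found, fall back to adapting the column's own expression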
+ c = self._corresponding_column(col, False)
+ if not c:
+ c = self.adapt_clause(col)
+
+ # anonymize labels in case they have a hardcoded name
+ if isinstance(c, expression._Label):
+ c = c.label(None)
+ return c
+
+ def adapted_row(self, row):
+ return AliasedRow(row, self.columns)
+
from sqlalchemy import util
class ClauseVisitor(object):
- """Traverses and visits ``ClauseElement`` structures.
-
- Calls visit_XXX() methods for each particular
- ``ClauseElement`` subclass encountered. Traversal of a
- hierarchy of ``ClauseElements`` is achieved via the
- ``traverse()`` method, which is passed the lead
- ``ClauseElement``.
-
- By default, ``ClauseVisitor`` traverses all elements
- fully. Options can be specified at the class level via the
- ``__traverse_options__`` dictionary which will be passed
- to the ``get_children()`` method of each ``ClauseElement``;
- these options can indicate modifications to the set of
- elements returned, such as to not return column collections
- (column_collections=False) or to return Schema-level items
- (schema_visitor=True).
-
- ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
- operation, which will produce a copy of a given ``ClauseElement``
- structure while at the same time allowing ``ClauseVisitor`` subclasses
- to modify the new structure in-place.
-
- """
__traverse_options__ = {}
- def traverse_single(self, obj, **kwargs):
- """visit a single element, without traversing its child elements."""
-
+ def traverse_single(self, obj):
for v in self._iterate_visitors:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
- return meth(obj, **kwargs)
+ return meth(obj)
- traverse_chained = traverse_single
-
def iterate(self, obj):
"""traverse the given expression structure, returning an iterator of all elements."""
-
- stack = [obj]
- traversal = util.deque()
- while stack:
- t = stack.pop()
- traversal.appendleft(t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- return iter(traversal)
-
- def traverse(self, obj, clone=False):
- """traverse and visit the given expression structure.
-
- Returns the structure given, or a copy of the structure if
- clone=True.
-
- When the copy operation takes place, the before_clone() method
- will receive each element before it is copied. If the method
- returns a non-None value, the return value is taken as the
- "copied" element and traversal will not descend further.
-
- The visit_XXX() methods receive the element *after* it's been
- copied. To compare an element to another regardless of
- one element being a cloned copy of the original, the
- '_cloned_set' attribute of ClauseElement can be used for the compare,
- i.e.::
-
- original in copied._cloned_set
-
-
- """
- if clone:
- return self._cloned_traversal(obj)
- else:
- return self._non_cloned_traversal(obj)
-
- def copy_and_process(self, list_):
- """Apply cloned traversal to the given list of elements, and return the new list."""
-
- return [self._cloned_traversal(x) for x in list_]
- def before_clone(self, elem):
- """receive pre-copied elements during a cloning traversal.
-
- If the method returns a new element, the element is used
- instead of creating a simple copy of the element. Traversal
- will halt on the newly returned element if it is re-encountered.
- """
- return None
-
- def _clone_element(self, elem, stop_on, cloned):
- for v in self._iterate_visitors:
- newelem = v.before_clone(elem)
- if newelem:
- stop_on.add(newelem)
- return newelem
-
- if elem not in cloned:
- # the full traversal will only make a clone of a particular element
- # once.
- cloned[elem] = elem._clone()
- return cloned[elem]
-
- def _cloned_traversal(self, obj):
- """a recursive traversal which creates copies of elements, returning the new structure."""
-
- stop_on = self.__traverse_options__.get('stop_on', [])
- return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True)
-
- def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False):
- if elem in stop_on:
- return elem
-
- if _clone_toplevel:
- elem = self._clone_element(elem, stop_on, cloned)
- if elem in stop_on:
- return elem
-
- def clone(element):
- return self._clone_element(element, stop_on, cloned)
- elem._copy_internals(clone=clone)
+ return iterate(obj, self.__traverse_options__)
- self.traverse_single(elem)
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
- for e in elem.get_children(**self.__traverse_options__):
- if e not in stop_on:
- self._cloned_traversal_impl(e, stop_on, cloned)
- return elem
+ visitors = {}
- def _non_cloned_traversal(self, obj):
- """a non-recursive, non-cloning traversal."""
-
- for target in self.iterate(obj):
- self.traverse_single(target)
- return obj
+ for name in dir(self):
+ if name.startswith('visit_'):
+ visitors[name[6:]] = getattr(self, name)
+
+ return traverse(obj, self.__traverse_options__, visitors)
def _iterate_visitors(self):
"""iterate through this visitor and each 'chained' visitor."""
tail._next = visitor
return self
-class NoColumnVisitor(ClauseVisitor):
- """ClauseVisitor with 'column_collections' set to False; will not
- traverse the front-facing Column collections on Table, Alias, Select,
- and CompoundSelect objects.
+class CloningVisitor(ClauseVisitor):
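+ """ClauseVisitor which clones the given expression structure on
+ traverse(), invoking visit_XXX() methods on the cloned elements."""
+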
+ def copy_and_process(self, list_):
+ """Apply cloned traversal to the given list of elements, and return the new list."""
+
+ return [self.traverse(x) for x in list_]
+
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
+
+ visitors = {}
+
+ for name in dir(self):
+ if name.startswith('visit_'):
+ visitors[name[6:]] = getattr(self, name)
+
+ return cloned_traverse(obj, self.__traverse_options__, visitors)
+
+class ReplacingCloningVisitor(CloningVisitor):
+ def replace(self, elem):
+ """receive pre-copied elements during a cloning traversal.
+
+ If the method returns a new element, the element is used
+ instead of creating a simple copy of the element. Traversal
+ will halt on the newly returned element if it is re-encountered.
+ """
+ return None
+
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
+
+ def replace(elem):
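+ # consult each visitor in the chain for a replacement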
+ for v in self._iterate_visitors:
+ e = v.replace(elem)
+ if e:
+ return e
+ return replacement_traverse(obj, self.__traverse_options__, replace)
+
+def iterate(obj, opts):
+ """traverse the given expression structure, returning an iterator.
+
+ traversal is configured to be breadth-first.
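+
+ e.g. for ``column_a == column_b``, the binary expression itself is
+ yielded first, followed by its left and right column elements.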
"""
+ stack = util.deque([obj])
+ while stack:
+ t = stack.popleft()
+ yield t
+ for c in t.get_children(**opts):
+ stack.append(c)
+
+def iterate_depthfirst(obj, opts):
+ """traverse the given expression structure, returning an iterator.
- __traverse_options__ = {'column_collections':False}
-
-class NullVisitor(ClauseVisitor):
- def traverse(self, obj, clone=False):
- next = getattr(self, '_next', None)
- if next:
- return next.traverse(obj, clone=clone)
- else:
- return obj
-
-def traverse(clause, **kwargs):
- """traverse the given clause, applying visit functions passed in as keyword arguments."""
+ traversal is configured to be depth-first.
+
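+ e.g. for ``column_a == column_b``, the left and right column
+ elements are yielded before the enclosing binary expression.
+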
+ """
+ stack = util.deque([obj])
+ traversal = util.deque()
+ while stack:
+ t = stack.pop()
+ traversal.appendleft(t)
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return iter(traversal)
+
+def traverse_using(iterator, obj, visitors):
+ """visit the given expression structure using the given iterator of objects."""
+
+ for target in iterator:
+ meth = visitors.get(target.__visit_name__, None)
+ if meth:
+ meth(target)
+ return obj
- clone = kwargs.pop('clone', False)
- class Vis(ClauseVisitor):
- __traverse_options__ = kwargs.pop('traverse_options', {})
- vis = Vis()
- for key in kwargs:
- setattr(vis, key, kwargs[key])
- return vis.traverse(clause, clone=clone)
+def traverse(obj, opts, visitors):
+ """traverse and visit the given expression structure using the default iterator."""
+
+ return traverse_using(iterate(obj, opts), obj, visitors)
+
+def traverse_depthfirst(obj, opts, visitors):
+ """traverse and visit the given expression structure using the depth-first iterator."""
+
+ return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
+
+def cloned_traverse(obj, opts, visitors):
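+ """clone the given expression structure, calling the given visitors
+ on each cloned element."""
+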
+ cloned = {}
+
+ def clone(element):
+ if element not in cloned:
+ cloned[element] = element._clone()
+ return cloned[element]
+
+ obj = clone(obj)
+ stack = [obj]
+
+ while stack:
+ t = stack.pop()
+ if t in cloned:
+ continue
+ t._copy_internals(clone=clone)
+
+ meth = visitors.get(t.__visit_name__, None)
+ if meth:
+ meth(t)
+
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return obj
+
+def replacement_traverse(obj, opts, replace):
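+ """clone the given expression structure, replacing elements with the
+ result of the given replace callable; traversal does not descend
+ into elements which have been replaced."""
+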
+ cloned = {}
+ stop_on = util.Set(opts.get('stop_on', []))
+
+ def clone(element):
+ newelem = replace(element)
+ if newelem:
+ stop_on.add(newelem)
+ return newelem
+
+ if element not in cloned:
+ cloned[element] = element._clone()
+ return cloned[element]
+ obj = clone(obj)
+ stack = [obj]
+ while stack:
+ t = stack.pop()
+ if t in stop_on:
+ continue
+ t._copy_internals(clone=clone)
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return obj
"""
from sqlalchemy import util
-from sqlalchemy.exceptions import CircularDependencyError
+from sqlalchemy.exc import CircularDependencyError
__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree']
for n in lead.cycles:
if n is not lead:
n._cyclical = True
- for (n,k) in list(edges.edges_by_parent(n)):
+ for (n, k) in list(edges.edges_by_parent(n)):
edges.add((lead, k))
- edges.remove((n,k))
+ edges.remove((n, k))
continue
else:
# long cycles not allowed
nodealldeps = node.all_deps()
if nodealldeps:
# iterate over independent node indexes in reverse order so we can efficiently remove them
- for index in xrange(len(independents)-1,-1,-1):
+ for index in xrange(len(independents) - 1, -1, -1):
child, childsubtree, childcycles = independents[index]
# if there is a dependency between this node and an independent node
if (childsubtree.intersection(nodealldeps) or childcycles.intersection(node.dependencies)):
# remove the child from list of independent subtrees
independents[index:index+1] = []
# add node as a new independent subtree
- independents.append((node,subtree,cycles))
+ independents.append((node, subtree, cycles))
# choose an arbitrary node from list of all independent subtrees
head = independents.pop()[0]
# add all other independent subtrees as a child of the chosen root
import inspect
import datetime as dt
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from sqlalchemy.util import pickle, Decimal as _python_Decimal
import sqlalchemy.util as util
NoneType = type(None)
def get_col_spec(self):
raise NotImplementedError()
-
def bind_processor(self, dialect):
return None
def __init__(self, *args, **kwargs):
if not hasattr(self.__class__, 'impl'):
- raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
+ raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
self.impl = self.__class__.impl(*args, **kwargs)
def dialect_impl(self, dialect, **kwargs):
typedesc = self.load_dialect_impl(dialect)
tt = self.copy()
if not isinstance(tt, self.__class__):
- raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
+ raise AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
tt.impl = typedesc
self._impl_dict[dialect] = tt
return tt
return self.impl.copy_value(value)
def compare_values(self, x, y):
- return self.impl.compare_values(x,y)
+ return self.impl.compare_values(x, y)
def is_mutable(self):
return self.impl.is_mutable()
class String(Concatenable, TypeEngine):
"""A sized string type.
- Usually corresponds to VARCHAR. Can also take Python unicode objects
+ In SQL, corresponds to VARCHAR. Can also take Python unicode objects
and encode to the database's encoding in bind params (and the reverse for
result sets).
- a String with no length will adapt itself automatically to a Text
- object at the dialect level (this behavior is deprecated in 0.4).
+ The `length` field is usually required when the `String` type is used within a
+ CREATE TABLE statement, since VARCHAR requires a length on most databases.
+ Currently SQLite is an exception to this.
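+
+ e.g.::
+
+ Column('sometext', String(50))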
+
"""
def __init__(self, length=None, convert_unicode=False, assert_unicode=None):
self.length = length
"param value %r" % value)
return value
else:
- raise exceptions.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+ raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
else:
return value
return process
else:
return None
- def dialect_impl(self, dialect, **kwargs):
- _for_ddl = kwargs.pop('_for_ddl', False)
- if _for_ddl and self.length is None:
- label = util.to_ascii(_for_ddl is True and
- '' or (' for column "%s"' % str(_for_ddl)))
- util.warn_deprecated(
- "Using String type with no length for CREATE TABLE "
- "is deprecated; use the Text type explicitly" + label)
- return TypeEngine.dialect_impl(self, dialect, **kwargs)
-
- def get_search_list(self):
- l = super(String, self).get_search_list()
- # if we are String or Unicode with no length,
- # return Text as the highest-priority type
- # to be adapted by the dialect
- if self.length is None and l[0] in (String, Unicode):
- return (Text,) + l
- else:
- return l
-
def get_dbapi_type(self, dbapi):
return dbapi.STRING
if value is None:
return None
return dt.datetime.utcfromtimestamp(0) + value
-
+
def process_result_value(self, value, dialect):
if dialect.__class__ in self.__supported:
return value
return None
return value - dt.datetime.utcfromtimestamp(0)
-class FLOAT(Float): pass
-TEXT = Text
-class NUMERIC(Numeric): pass
-class DECIMAL(Numeric): pass
-class INT(Integer): pass
+class FLOAT(Float):
+ """The SQL FLOAT type."""
+
+
+class NUMERIC(Numeric):
+ """The SQL NUMERIC type."""
+
+
+class DECIMAL(Numeric):
+ """The SQL DECIMAL type."""
+
+
+class INT(Integer):
+ """The SQL INT or INTEGER type."""
+
+
INTEGER = INT
-class SMALLINT(Smallinteger): pass
-class TIMESTAMP(DateTime): pass
-class DATETIME(DateTime): pass
-class DATE(Date): pass
-class TIME(Time): pass
-class CLOB(Text): pass
-class VARCHAR(String): pass
-class CHAR(String): pass
-class NCHAR(Unicode): pass
-class BLOB(Binary): pass
-class BOOLEAN(Boolean): pass
+
+class SMALLINT(Smallinteger):
+ """The SQL SMALLINT type."""
+
+
+class TIMESTAMP(DateTime):
+ """The SQL TIMESTAMP type."""
+
+
+class DATETIME(DateTime):
+ """The SQL DATETIME type."""
+
+
+class DATE(Date):
+ """The SQL DATE type."""
+
+
+class TIME(Time):
+ """The SQL TIME type."""
+
+
+TEXT = Text
+
+class CLOB(Text):
+ """The SQL CLOB type."""
+
+
+class VARCHAR(String):
+ """The SQL VARCHAR type."""
+
+
+class CHAR(String):
+ """The SQL CHAR type."""
+
+
+class NCHAR(Unicode):
+ """The SQL NCHAR type."""
+
+
+class BLOB(Binary):
+ """The SQL BLOB type."""
+
+
+class BOOLEAN(Boolean):
+ """The SQL BOOLEAN type."""
NULLTYPE = NullType()
import __builtin__
types = __import__('types')
-from sqlalchemy import exceptions
+from sqlalchemy import exc
try:
import thread, threading
try:
Set = set
+ FrozenSet = frozenset
set_types = set, sets.Set
except NameError:
set_types = sets.Set,
- # layer some of __builtin__.set's binop behavior onto sets.Set
- class Set(sets.Set):
+
+ def py24_style_ops():
+ """Layer some of __builtin__.set's binop behavior onto sets.Set."""
+
def _binary_sanity_check(self, other):
pass
-
def issubset(self, iterable):
other = type(self)(iterable)
return sets.Set.issubset(self, other)
def __ge__(self, other):
sets.Set._binary_sanity_check(self, other)
return sets.Set.__ge__(self, other)
-
# lt and gt still require a BaseSet
def __lt__(self, other):
sets.Set._binary_sanity_check(self, other)
if not isinstance(other, sets.BaseSet):
return NotImplemented
return sets.Set.__isub__(self, other)
+ return locals()
+
+ py24_style_ops = py24_style_ops()
+ Set = type('Set', (sets.Set,), py24_style_ops)
+ FrozenSet = type('FrozenSet', (sets.ImmutableSet,), py24_style_ops)
+ del py24_style_ops
+
+EMPTY_SET = FrozenSet()
try:
import cPickle as pickle
try:
from operator import attrgetter
-except:
+except ImportError:
def attrgetter(attribute):
return lambda value: getattr(value, attribute)
+try:
+ from operator import itemgetter
+except ImportError:
+ def itemgetter(attribute):
+ return lambda value: value[attribute]
+
if sys.version_info >= (2, 5):
class PopulateDict(dict):
"""a dict which populates missing values via a creation function.
class deque(list):
def appendleft(self, x):
self.insert(0, x)
-
+
def extendleft(self, iterable):
self[0:0] = list(iterable)
def popleft(self):
return self.pop(0)
-
+
def rotate(self, n):
for i in xrange(n):
self.appendleft(self.pop())
-
+
def to_list(x, default=None):
if x is None:
return default
else:
return x
-def array_as_starargs_decorator(func):
+def array_as_starargs_decorator(fn):
"""Interpret a single positional array argument as
*args for the decorated method.
-
+
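+ e.g. a method decorated with this accepts both
+ ``method([a, b, c])`` and ``method(a, b, c)``.
+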
"""
+
def starargs_as_list(self, *args, **kwargs):
- if len(args) == 1:
- return func(self, *to_list(args[0], []), **kwargs)
+ if len(args) == 1 and not isinstance(args[0], (tuple, basestring)):
+ return fn(self, *to_list(args[0], []), **kwargs)
else:
- return func(self, *args, **kwargs)
- return starargs_as_list
-
+ return fn(self, *args, **kwargs)
+ starargs_as_list.__doc__ = fn.__doc__
+ return function_named(starargs_as_list, fn.__name__)
+
+def array_as_starargs_fn_decorator(fn):
+ """Interpret a single positional array argument as
+ *args for the decorated function.
+
+ """
+
+ def starargs_as_list(*args, **kwargs):
+ if len(args) == 1 and not isinstance(args[0], (tuple, basestring)):
+ return fn(*to_list(args[0], []), **kwargs)
+ else:
+ return fn(*args, **kwargs)
+ starargs_as_list.__doc__ = fn.__doc__
+ return function_named(starargs_as_list, fn.__name__)
+
def to_set(x):
if x is None:
return Set()
"""Return the full set of legal kwargs for the given `func`."""
return inspect.getargspec(func)[0]
+def format_argspec_plus(fn, grouped=True):
+ """Returns a dictionary of formatted, introspected function arguments.
+
+ An enhanced variant of inspect.formatargspec to support code generation.
+
+ fn
+ An inspectable callable
+ grouped
+ Defaults to True; include (parens, around, argument) lists
+
+ Returns:
+
+ args
+ Full inspect.formatargspec for fn
+ self_arg
+ The name of the first positional argument, or None
+ apply_pos
+ args, re-written in calling rather than receiving syntax. Arguments are
+ passed positionally.
+ apply_kw
+ Like apply_pos, except keyword-ish args are passed as keywords.
+
+ Example::
+
+ >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
+ {'args': '(self, a, b, c=3, **d)',
+ 'self_arg': 'self',
+ 'apply_kw': '(self, a, b, c=c, **d)',
+ 'apply_pos': '(self, a, b, c, **d)'}
+
+ """
+ spec = inspect.getargspec(fn)
+ args = inspect.formatargspec(*spec)
+ self_arg = spec[0] and spec[0][0] or None
+ apply_pos = inspect.formatargspec(spec[0], spec[1], spec[2])
+ defaulted_vals = spec[3] is not None and spec[0][-len(spec[3]):] or ()
+ apply_kw = inspect.formatargspec(spec[0], spec[1], spec[2], defaulted_vals,
+ formatvalue=lambda x: '=' + x)
+ if grouped:
+ return dict(args=args, self_arg=self_arg,
+ apply_pos=apply_pos, apply_kw=apply_kw)
+ else:
+ return dict(args=args[1:-1], self_arg=self_arg,
+ apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1])
+
+def format_argspec_init(method, grouped=True):
+ """format_argspec_plus with considerations for typical __init__ methods
+
+ Wraps format_argspec_plus with error handling strategies for typical
+ __init__ cases::
+
+ object.__init__ -> (self)
+ other unreflectable (usually C) -> (self, *args, **kwargs)
+
+ """
+ try:
+ return format_argspec_plus(method, grouped=grouped)
+ except TypeError:
+ self_arg = 'self'
+ if method is object.__init__:
+ args = grouped and '(self)' or 'self'
+ else:
+ args = (grouped and '(self, *args, **kwargs)'
+ or 'self, *args, **kwargs')
+ return dict(self_arg='self', args=args, apply_pos=args, apply_kw=args)
+
+def getargspec_init(method):
+ """inspect.getargspec with considerations for typical __init__ methods
+
+ Wraps inspect.getargspec with error handling for typical __init__ cases::
+
+ object.__init__ -> (self)
+ other unreflectable (usually C) -> (self, *args, **kwargs)
+
+ """
+ try:
+ return inspect.getargspec(method)
+ except TypeError:
+ if method is object.__init__:
+ return (['self'], None, None, None)
+ else:
+ return (['self'], 'args', 'kwargs', None)
+
def unbound_method_to_callable(func_or_cls):
"""Adjust the incoming callable such that a 'self' argument is not required."""
-
+
if isinstance(func_or_cls, types.MethodType) and not func_or_cls.im_self:
return func_or_cls.im_func
else:
return func_or_cls
+def class_hierarchy(cls):
+ """Return an unordered sequence of all classes related to cls.
+
+ Traverses diamond hierarchies.
+
+ Fibs slightly: subclasses of builtin types are not returned. Thus
+ class_hierarchy(class A(object)) returns (A, object), not A plus every
+ class systemwide that derives from object.
+
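+ e.g.::
+
+ class A(object): pass
+ class B(A): pass
+
+ # 'hier' contains A, B and object, in no particular order
+ hier = class_hierarchy(B)
+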
+ """
+ hier = Set([cls])
+ process = list(cls.__mro__)
+ while process:
+ c = process.pop()
+ for b in [_ for _ in c.__bases__ if _ not in hier]:
+ process.append(b)
+ hier.add(b)
+ if c.__module__ == '__builtin__':
+ continue
+ for s in [_ for _ in c.__subclasses__() if _ not in hier]:
+ process.append(s)
+ hier.add(s)
+ return list(hier)
+
# from paste.deploy.converters
def asbool(obj):
if isinstance(obj, (str, unicode)):
return specimen.__emulates__
isa = isinstance(specimen, type) and issubclass or isinstance
- if isa(specimen, list): return list
- if isa(specimen, set_types): return Set
- if isa(specimen, dict): return dict
+ if isa(specimen, list):
+ return list
+ elif isa(specimen, set_types):
+ return Set
+ elif isa(specimen, dict):
+ return dict
if hasattr(specimen, 'append'):
return list
return arg
else:
if isinstance(argtype, tuple):
- raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg))))
+ raise exc.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg))))
else:
- raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg))))
+ raise exc.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg))))
+_creation_order = 1
+def set_creation_order(instance):
+ """assign a '_creation_order' sequence to the given instance.
+
+ This allows multiple instances to be sorted in order of
+ creation (typically within a single thread; the counter is
+ not particularly threadsafe).
+
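+ e.g., to order a list of instances oldest-first::
+
+ instances.sort(lambda a, b: cmp(a._creation_order, b._creation_order))
+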
+ """
+ global _creation_order
+ instance._creation_order = _creation_order
+ _creation_order += 1
+
def warn_exception(func, *args, **kwargs):
"""executes the given function, catches all exceptions and converts to a warning."""
try:
class NotImplProperty(object):
- """a property that raises ``NotImplementedError``."""
+ """a property that raises ``NotImplementedError``."""
- def __init__(self, doc):
- self.__doc__ = doc
+ def __init__(self, doc):
+ self.__doc__ = doc
- def __set__(self, obj, value):
- raise NotImplementedError()
+ def __set__(self, obj, value):
+ raise NotImplementedError()
- def __delete__(self, obj):
- raise NotImplementedError()
+ def __delete__(self, obj):
+ raise NotImplementedError()
- def __get__(self, obj, owner):
- if obj is None:
- return self
- else:
- raise NotImplementedError()
+ def __get__(self, obj, owner):
+ if obj is None:
+ return self
+ else:
+ raise NotImplementedError()
class OrderedProperties(object):
"""An object that maintains the order in which attributes are set upon it.
def __contains__(self, key):
return key in self._data
-
+
def update(self, value):
self._data.update(value)
-
+
def get(self, key, default=None):
if key in self:
return self[key]
def clear(self):
self._list = []
dict.clear(self)
-
+
+ def sort(self, fn=None):
+ self._list.sort(fn)
+
def update(self, ____sequence=None, **kwargs):
if ____sequence is not None:
if hasattr(____sequence, 'keys'):
if d is not None:
self.update(d)
- def add(self, key):
- if key not in self:
- self._list.append(key)
- Set.add(self, key)
+ def add(self, element):
+ if element not in self:
+ self._list.append(element)
+ Set.add(self, element)
def remove(self, element):
Set.remove(self, element)
self._list.remove(element)
+ def insert(self, pos, element):
+ if element not in self:
+ self._list.insert(pos, element)
+ Set.add(self, element)
+
def discard(self, element):
- try:
- Set.remove(self, element)
- except KeyError:
- pass
- else:
+ if element in self:
self._list.remove(element)
+ Set.remove(self, element)
def clear(self):
Set.clear(self)
return iter(self._list)
def __repr__(self):
- return '%s(%r)' % (self.__class__.__name__, self._list)
+ return '%s(%r)' % (self.__class__.__name__, self._list)
__str__ = __repr__
def update(self, iterable):
- add = self.add
- for i in iterable:
- add(i)
- return self
+ add = self.add
+ for i in iterable:
+ add(i)
+ return self
__ior__ = update
def union(self, other):
- result = self.__class__(self)
- result.update(other)
- return result
+ result = self.__class__(self)
+ result.update(other)
+ return result
__or__ = union
__iand__ = intersection_update
def symmetric_difference_update(self, other):
- Set.symmetric_difference_update(self, other)
- self._list = [ a for a in self._list if a in self]
- self._list += [ a for a in other._list if a in self]
- return self
+ Set.symmetric_difference_update(self, other)
+ self._list = [ a for a in self._list if a in self]
+ self._list += [ a for a in other._list if a in self]
+ return self
__ixor__ = symmetric_difference_update
def _get_key(self):
return self.scopefunc()
+class WeakCompositeKey(object):
+ """an weak-referencable, hashable collection which is strongly referenced
+ until any one of its members is garbage collected.
+
+ """
+ keys = Set()
+
+ def __init__(self, *args):
+ self.args = [self.__ref(arg) for arg in args]
+ WeakCompositeKey.keys.add(self)
+
+ def __ref(self, arg):
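+ # classes are referenced weakly; any other value is held
+ # strongly via a closure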
+ if isinstance(arg, type):
+ return weakref.ref(arg, self.__remover)
+ else:
+ return lambda: arg
+
+ def __remover(self, wr):
+ WeakCompositeKey.keys.discard(self)
+
+ def __hash__(self):
+ return hash(tuple(self))
+
+ def __cmp__(self, other):
+ return cmp(tuple(self), tuple(other))
+
+ def __iter__(self):
+ return iter([arg() for arg in self.args])
+
class _symbol(object):
def __init__(self, name):
"""Construct a new named symbol."""
finally:
symbol._lock.release()
-
def as_interface(obj, cls=None, methods=None, required=None):
"""Ensure basic interface compliance for an instance or dict of callables.
fn.func_defaults, fn.func_closure)
return fn
-def conditional_cache_decorator(func):
- """apply conditional caching to the return value of a function."""
-
- return cache_decorator(func, conditional=True)
-
-def cache_decorator(func, conditional=False):
+def cache_decorator(func):
"""apply caching to the return value of a function."""
name = '_cached_' + func.__name__
-
+
def do_with_cache(self, *args, **kwargs):
- if conditional:
- cache = kwargs.pop('cache', False)
- if not cache:
- return func(self, *args, **kwargs)
try:
return getattr(self, name)
except AttributeError:
setattr(self, name, value)
return value
return do_with_cache
-
+
def reset_cached(instance, name):
try:
delattr(instance, '_cached_' + name)
except AttributeError:
pass
+class WeakIdentityMapping(weakref.WeakKeyDictionary):
+ """A WeakKeyDictionary with an object identity index.
+
+ Adds a .by_id dictionary to a regular WeakKeyDictionary. Trades
+ performance during mutation operations for accelerated lookups by id().
+
+ The usual cautions about weak dictionaries and iteration also apply to
+ this subclass.
+
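+ e.g.::
+
+ wim = WeakIdentityMapping()
+ wim[someobj] = 'some value'
+ assert wim.by_id[id(someobj)] == 'some value'
+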
+ """
+ _none = symbol('none')
+
+ def __init__(self):
+ weakref.WeakKeyDictionary.__init__(self)
+ self.by_id = {}
+ self._weakrefs = {}
+
+ def __setitem__(self, object, value):
+ oid = id(object)
+ self.by_id[oid] = value
+ if oid not in self._weakrefs:
+ self._weakrefs[oid] = self._ref(object)
+ weakref.WeakKeyDictionary.__setitem__(self, object, value)
+
+ def __delitem__(self, object):
+ del self._weakrefs[id(object)]
+ del self.by_id[id(object)]
+ weakref.WeakKeyDictionary.__delitem__(self, object)
+
+ def setdefault(self, object, default=None):
+ value = weakref.WeakKeyDictionary.setdefault(self, object, default)
+ oid = id(object)
+ if value is default:
+ self.by_id[oid] = default
+ if oid not in self._weakrefs:
+ self._weakrefs[oid] = self._ref(object)
+ return value
+
+ def pop(self, object, default=_none):
+ if default is self._none:
+ value = weakref.WeakKeyDictionary.pop(self, object)
+ else:
+ value = weakref.WeakKeyDictionary.pop(self, object, default)
+ if id(object) in self.by_id:
+ del self._weakrefs[id(object)]
+ del self.by_id[id(object)]
+ return value
+
+ def popitem(self):
+ item = weakref.WeakKeyDictionary.popitem(self)
+ oid = id(item[0])
+ del self._weakrefs[oid]
+ del self.by_id[oid]
+ return item
+
+ def clear(self):
+ self._weakrefs.clear()
+ self.by_id.clear()
+ weakref.WeakKeyDictionary.clear(self)
+
+ def update(self, *a, **kw):
+ raise NotImplementedError
+
+ def _cleanup(self, wr, key=None):
+ if key is None:
+ key = wr.key
+ try:
+ del self._weakrefs[key]
+ except (KeyError, AttributeError): # pragma: no cover
+ pass # pragma: no cover
+ try:
+ del self.by_id[key]
+ except (KeyError, AttributeError): # pragma: no cover
+ pass # pragma: no cover
+ if sys.version_info < (2, 4): # pragma: no cover
+ def _ref(self, object):
+ oid = id(object)
+ return weakref.ref(object, lambda wr: self._cleanup(wr, oid))
+ else:
+ class _keyed_weakref(weakref.ref):
+ def __init__(self, object, callback):
+ weakref.ref.__init__(self, object, callback)
+ self.key = id(object)
+
+ def _ref(self, object):
+ return self._keyed_weakref(object, self._cleanup)
+
+
def warn(msg):
if isinstance(msg, basestring):
- warnings.warn(msg, exceptions.SAWarning, stacklevel=3)
+ warnings.warn(msg, exc.SAWarning, stacklevel=3)
else:
warnings.warn(msg, stacklevel=3)
def warn_deprecated(msg):
- warnings.warn(msg, exceptions.SADeprecationWarning, stacklevel=3)
+ warnings.warn(msg, exc.SADeprecationWarning, stacklevel=3)
def deprecated(message=None, add_deprecation_to_docstring=True):
"""Decorates a function and issues a deprecation warning on use.
def decorate(fn):
return _decorate_with_warning(
- fn, exceptions.SADeprecationWarning,
+ fn, exc.SADeprecationWarning,
message % dict(func=fn.__name__), header)
return decorate
def decorate(fn):
return _decorate_with_warning(
- fn, exceptions.SAPendingDeprecationWarning,
+ fn, exc.SAPendingDeprecationWarning,
message % dict(func=fn.__name__), header)
return decorate
import testenv; testenv.configure_for_tests()
import sqlalchemy.topological as topological
from sqlalchemy import util
-from testlib import *
+from testlib import TestBase
class DependencySortTest(TestBase):
"""Tests exceptions and DB-API exception wrapping."""
import testenv; testenv.configure_for_tests()
-import sys, unittest
+import unittest
import exceptions as stdlib_exceptions
-from sqlalchemy import exceptions as sa_exceptions
-from testlib import *
+from sqlalchemy import exc as sa_exceptions
class Error(stdlib_exceptions.StandardError):
# subclasses of sqlalchemy.exceptions.DBAPIError
try:
raise sa_exceptions.DBAPIError.instance(
- '', [], sa_exceptions.AssertionError())
+ '', [], sa_exceptions.ArgumentError())
except sa_exceptions.DBAPIError, e:
self.assert_(e.__class__ is sa_exceptions.DBAPIError)
- except sa_exceptions.AssertionError:
+ except sa_exceptions.ArgumentError:
self.assert_(False)
def test_db_error_keyboard_interrupt(self):
import testenv; testenv.configure_for_tests()
-import unittest
-from sqlalchemy import util, sql, exceptions
-from testlib import *
-from testlib import sorted
+import threading, unittest
+from sqlalchemy import util, sql, exc
+from testlib import TestBase
+from testlib.testing import eq_, is_, ne_
+from testlib.compat import frozenset, set, sorted
class OrderedDictTest(TestBase):
def test_odict(self):
o['snack'] = 'attack'
o['c'] = 3
- self.assert_(o.keys() == ['a', 'b', 'snack', 'c'])
- self.assert_(o.values() == [1, 2, 'attack', 3])
+ eq_(o.keys(), ['a', 'b', 'snack', 'c'])
+ eq_(o.values(), [1, 2, 'attack', 3])
o.pop('snack')
- self.assert_(o.keys() == ['a', 'b', 'c'])
- self.assert_(o.values() == [1, 2, 3])
+ eq_(o.keys(), ['a', 'b', 'c'])
+ eq_(o.values(), [1, 2, 3])
o2 = util.OrderedDict(d=4)
o2['e'] = 5
- self.assert_(o2.keys() == ['d', 'e'])
- self.assert_(o2.values() == [4, 5])
+ eq_(o2.keys(), ['d', 'e'])
+ eq_(o2.values(), [4, 5])
o.update(o2)
- self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e'])
- self.assert_(o.values() == [1, 2, 3, 4, 5])
+ eq_(o.keys(), ['a', 'b', 'c', 'd', 'e'])
+ eq_(o.values(), [1, 2, 3, 4, 5])
o.setdefault('c', 'zzz')
o.setdefault('f', 6)
- self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e', 'f'])
- self.assert_(o.values() == [1, 2, 3, 4, 5, 6])
+ eq_(o.keys(), ['a', 'b', 'c', 'd', 'e', 'f'])
+ eq_(o.values(), [1, 2, 3, 4, 5, 6])
class OrderedSetTest(TestBase):
def test_mutators_against_iter(self):
# testing a set modified against an iterator
o = util.OrderedSet([3,2, 4, 5])
- self.assertEquals(o.difference(iter([3,4])),
- util.OrderedSet([2,5]))
- self.assertEquals(o.intersection(iter([3,4, 6])),
- util.OrderedSet([3, 4]))
- self.assertEquals(o.union(iter([3,4, 6])),
- util.OrderedSet([2, 3, 4, 5, 6]))
+ eq_(o.difference(iter([3,4])), util.OrderedSet([2,5]))
+ eq_(o.intersection(iter([3,4, 6])), util.OrderedSet([3, 4]))
+ eq_(o.union(iter([3,4, 6])), util.OrderedSet([2, 3, 4, 5, 6]))
class ColumnCollectionTest(TestBase):
def test_in(self):
try:
cc['col1'] in cc
assert False
- except exceptions.ArgumentError, e:
- assert str(e) == "__contains__ requires a string argument"
+ except exc.ArgumentError, e:
+ eq_(str(e), "__contains__ requires a string argument")
def test_compare(self):
cc1 = sql.ColumnCollection()
m3 = MyClass(3, 4)
assert m1 is m3
assert m2 is not m3
- assert len(util.ArgSingleton.instances) == 2
+ eq_(len(util.ArgSingleton.instances), 2)
m1 = m2 = m3 = None
MyClass.dispose(MyClass)
- assert len(util.ArgSingleton.instances) == 0
+ eq_(len(util.ArgSingleton.instances), 0)
class ImmutableSubclass(str):
def assert_eq(self, identityset, expected_iterable):
expected = sorted([id(o) for o in expected_iterable])
found = sorted([id(o) for o in identityset])
- self.assertEquals(found, expected)
+ eq_(found, expected)
def test_init(self):
ids = util.IdentitySet([1,2,3,2,1])
ids.remove(o1)
self.assertRaises(KeyError, ids.remove, o1)
- self.assert_(ids.copy() == ids)
- self.assert_(ids != None)
- self.assert_(not(ids == None))
- self.assert_(ids != IdentitySet([o1,o2,o3]))
+ eq_(ids.copy(), ids)
+
+ # explicit __eq__ and __ne__ tests
+ assert ids != None
+ assert not(ids == None)
+
+ ne_(ids, IdentitySet([o1,o2,o3]))
ids.clear()
- self.assert_(o1 not in ids)
+ assert o1 not in ids
ids.add(o2)
- self.assert_(o2 in ids)
- self.assert_(ids.pop() == o2)
+ assert o2 in ids
+ eq_(ids.pop(), o2)
ids.add(o1)
- self.assert_(len(ids) == 1)
+ eq_(len(ids), 1)
isuper = IdentitySet([o1,o2])
- self.assert_(ids < isuper)
- self.assert_(ids.issubset(isuper))
- self.assert_(isuper.issuperset(ids))
- self.assert_(isuper > ids)
-
- self.assert_(ids.union(isuper) == isuper)
- self.assert_(ids | isuper == isuper)
- self.assert_(isuper - ids == IdentitySet([o2]))
- self.assert_(isuper.difference(ids) == IdentitySet([o2]))
- self.assert_(ids.intersection(isuper) == IdentitySet([o1]))
- self.assert_(ids & isuper == IdentitySet([o1]))
- self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2]))
- self.assert_(ids ^ isuper == IdentitySet([o2]))
+ assert ids < isuper
+ assert ids.issubset(isuper)
+ assert isuper.issuperset(ids)
+ assert isuper > ids
+
+ eq_(ids.union(isuper), isuper)
+ eq_(ids | isuper, isuper)
+ eq_(isuper - ids, IdentitySet([o2]))
+ eq_(isuper.difference(ids), IdentitySet([o2]))
+ eq_(ids.intersection(isuper), IdentitySet([o1]))
+ eq_(ids & isuper, IdentitySet([o1]))
+ eq_(ids.symmetric_difference(isuper), IdentitySet([o2]))
+ eq_(ids ^ isuper, IdentitySet([o2]))
ids.update(isuper)
ids |= isuper
ids.update('foobar')
try:
ids |= 'foobar'
- self.assert_(False)
+ assert False
except TypeError:
- self.assert_(True)
+ assert True
try:
s = set([o1,o2])
s |= ids
- self.assert_(False)
+ assert False
except TypeError:
- self.assert_(True)
+ assert True
self.assertRaises(TypeError, cmp, ids)
self.assertRaises(TypeError, hash, ids)
s1 = set([1,2,3])
s2 = set([3,4,5])
- self.assertEquals(os1 - os2, util.IdentitySet([1, 2]))
- self.assertEquals(os2 - os1, util.IdentitySet([4, 5]))
+ eq_(os1 - os2, util.IdentitySet([1, 2]))
+ eq_(os2 - os1, util.IdentitySet([4, 5]))
self.assertRaises(TypeError, lambda: os1 - s2)
self.assertRaises(TypeError, lambda: os1 - [3, 4, 5])
self.assertRaises(TypeError, lambda: s1 - os2)
def _ok(self, instance):
iterator = util.dictlike_iteritems(instance)
- self.assertEquals(set(iterator), self.baseline)
+ eq_(set(iterator), self.baseline)
def _notok(self, instance):
self.assertRaises(TypeError,
self._notok(duck6())
+class DuckTypeCollectionTest(TestBase):
+ def test_sets(self):
+ import sets
+ class SetLike(object):
+ def add(self):
+ pass
+
+ class ForcedSet(list):
+ __emulates__ = set
+
+ for type_ in (set,
+ sets.Set,
+ util.Set,
+ SetLike,
+ ForcedSet):
+ eq_(util.duck_type_collection(type_), util.Set)
+ instance = type_()
+ eq_(util.duck_type_collection(instance), util.Set)
+
+ for type_ in (frozenset,
+ sets.ImmutableSet,
+ util.FrozenSet):
+ is_(util.duck_type_collection(type_), None)
+ instance = type_()
+ is_(util.duck_type_collection(instance), None)
+
+
class ArgInspectionTest(TestBase):
def test_get_cls_kwargs(self):
class A(object):
pass
def test(cls, *expected):
- self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected))
+ eq_(set(util.get_cls_kwargs(cls)), set(expected))
test(A, 'a')
test(A1, 'a1')
def f4(**foo): pass
def test(fn, *expected):
- self.assertEquals(set(util.get_func_kwargs(fn)), set(expected))
+ eq_(set(util.get_func_kwargs(fn)), set(expected))
test(f1)
test(f2, 'foo')
assert rt is sym1
assert rt is sym2
+class WeakIdentityMappingTest(TestBase):
+ class Data(object):
+ pass
+
+ def _some_data(self, some=20):
+ return [self.Data() for _ in xrange(some)]
+
+ def _fixture(self, some=20):
+ data = self._some_data(some)
+ wim = util.WeakIdentityMapping()
+ for idx, obj in enumerate(data):
+ wim[obj] = idx
+ return data, wim
+
+ def test_delitem(self):
+ data, wim = self._fixture()
+ needle = data[-1]
+
+ assert needle in wim
+ assert id(needle) in wim.by_id
+ eq_(wim[needle], wim.by_id[id(needle)])
+
+ del wim[needle]
+
+ assert needle not in wim
+ assert id(needle) not in wim.by_id
+ eq_(len(wim), (len(data) - 1))
+
+ data.remove(needle)
+
+ assert needle not in wim
+ assert id(needle) not in wim.by_id
+ eq_(len(wim), len(data))
+
+ def test_setitem(self):
+ data, wim = self._fixture()
+
+ o1, oid1 = data[-1], id(data[-1])
+
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(wim[o1], wim.by_id[oid1])
+ id_keys = set(wim.by_id.keys())
+
+ wim[o1] = 1234
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(wim[o1], wim.by_id[oid1])
+ eq_(set(wim.by_id.keys()), id_keys)
+
+ o2 = self.Data()
+ oid2 = id(o2)
+
+ wim[o2] = 5678
+ assert o2 in wim
+ assert oid2 in wim.by_id
+ eq_(wim[o2], wim.by_id[oid2])
+
+ def test_pop(self):
+ data, wim = self._fixture()
+ needle = data[-1]
+
+ needle = data.pop()
+ assert needle in wim
+ assert id(needle) in wim.by_id
+ eq_(wim[needle], wim.by_id[id(needle)])
+ eq_(len(wim), (len(data) + 1))
+
+ wim.pop(needle)
+ assert needle not in wim
+ assert id(needle) not in wim.by_id
+ eq_(len(wim), len(data))
+
+ def test_pop_default(self):
+ data, wim = self._fixture()
+ needle = data[-1]
+
+ value = wim[needle]
+ x = wim.pop(needle, 123)
+ ne_(x, 123)
+ eq_(x, value)
+ assert needle not in wim
+ assert id(needle) not in wim.by_id
+ eq_(len(data), (len(wim) + 1))
+
+ n2 = self.Data()
+ y = wim.pop(n2, 456)
+ eq_(y, 456)
+ assert n2 not in wim
+ assert id(n2) not in wim.by_id
+ eq_(len(data), (len(wim) + 1))
+
+ def test_popitem(self):
+ data, wim = self._fixture()
+ (needle, idx) = wim.popitem()
+
+ assert needle in data
+ eq_(len(data), (len(wim) + 1))
+ assert id(needle) not in wim.by_id
+
+ def test_setdefault(self):
+ data, wim = self._fixture()
+
+ o1 = self.Data()
+ oid1 = id(o1)
+
+ assert o1 not in wim
+
+ res1 = wim.setdefault(o1, 123)
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(res1, 123)
+ id_keys = set(wim.by_id.keys())
+
+ res2 = wim.setdefault(o1, 456)
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(res2, 123)
+ assert set(wim.by_id.keys()) == id_keys
+
+ del wim[o1]
+ assert o1 not in wim
+ assert oid1 not in wim.by_id
+ ne_(set(wim.by_id.keys()), id_keys)
+
+ res3 = wim.setdefault(o1, 789)
+ assert o1 in wim
+ assert oid1 in wim.by_id
+ eq_(res3, 789)
+ eq_(set(wim.by_id.keys()), id_keys)
+
+ def test_clear(self):
+ data, wim = self._fixture()
+
+ assert len(data) == len(wim) == len(wim.by_id)
+ wim.clear()
+
+ eq_(wim, {})
+ eq_(wim.by_id, {})
+
+ def test_update(self):
+ data, wim = self._fixture()
+ self.assertRaises(NotImplementedError, wim.update)
+
+ def test_weak_clear(self):
+ data, wim = self._fixture()
+
+ assert len(data) == len(wim) == len(wim.by_id)
+
+ del data[:]
+ eq_(wim, {})
+ eq_(wim.by_id, {})
+ eq_(wim._weakrefs, {})
+
+ def test_weak_single(self):
+ data, wim = self._fixture()
+
+ assert len(data) == len(wim) == len(wim.by_id)
+
+ oid = id(data[0])
+ del data[0]
+
+ assert len(data) == len(wim) == len(wim.by_id)
+ assert oid not in wim.by_id
+
+ def test_weak_threadhop(self):
+ data, wim = self._fixture()
+ data = set(data)
+
+ cv = threading.Condition()
+
+ def empty(obj):
+ cv.acquire()
+ obj.clear()
+ cv.notify()
+ cv.release()
+
+ th = threading.Thread(target=empty, args=(data,))
+
+ cv.acquire()
+ th.start()
+ cv.wait()
+ cv.release()
+
+ eq_(wim, {})
+ eq_(wim.by_id, {})
+ eq_(wim._weakrefs, {})
+
+
+class TestFormatArgspec(TestBase):
+ def test_specs(self):
+ def test(fn, wanted, grouped=None):
+ if grouped is None:
+ parsed = util.format_argspec_plus(fn)
+ else:
+ parsed = util.format_argspec_plus(fn, grouped=grouped)
+ eq_(parsed, wanted)
+
+ test(lambda: None,
+ {'args': '()', 'self_arg': None,
+ 'apply_kw': '()', 'apply_pos': '()' })
+
+ test(lambda: None,
+ {'args': '', 'self_arg': None,
+ 'apply_kw': '', 'apply_pos': '' },
+ grouped=False)
+
+ test(lambda self: None,
+ {'args': '(self)', 'self_arg': 'self',
+ 'apply_kw': '(self)', 'apply_pos': '(self)' })
+
+ test(lambda self: None,
+ {'args': 'self', 'self_arg': 'self',
+ 'apply_kw': 'self', 'apply_pos': 'self' },
+ grouped=False)
+
+ test(lambda *a: None,
+ {'args': '(*a)', 'self_arg': None,
+ 'apply_kw': '(*a)', 'apply_pos': '(*a)' })
+
+ test(lambda **kw: None,
+ {'args': '(**kw)', 'self_arg': None,
+ 'apply_kw': '(**kw)', 'apply_pos': '(**kw)' })
+
+ test(lambda *a, **kw: None,
+ {'args': '(*a, **kw)', 'self_arg': None,
+ 'apply_kw': '(*a, **kw)', 'apply_pos': '(*a, **kw)' })
+
+ test(lambda a, *b: None,
+ {'args': '(a, *b)', 'self_arg': 'a',
+ 'apply_kw': '(a, *b)', 'apply_pos': '(a, *b)' })
+
+ test(lambda a, **b: None,
+ {'args': '(a, **b)', 'self_arg': 'a',
+ 'apply_kw': '(a, **b)', 'apply_pos': '(a, **b)' })
+
+ test(lambda a, *b, **c: None,
+ {'args': '(a, *b, **c)', 'self_arg': 'a',
+ 'apply_kw': '(a, *b, **c)', 'apply_pos': '(a, *b, **c)' })
+
+ test(lambda a, b=1, **c: None,
+ {'args': '(a, b=1, **c)', 'self_arg': 'a',
+ 'apply_kw': '(a, b=b, **c)', 'apply_pos': '(a, b, **c)' })
+
+ test(lambda a=1, b=2: None,
+ {'args': '(a=1, b=2)', 'self_arg': 'a',
+ 'apply_kw': '(a=a, b=b)', 'apply_pos': '(a, b)' })
+
+ test(lambda a=1, b=2: None,
+ {'args': 'a=1, b=2', 'self_arg': 'a',
+ 'apply_kw': 'a=a, b=b', 'apply_pos': 'a, b' },
+ grouped=False)
+
+ def test_init_grouped(self):
+ object_spec = {
+ 'args': '(self)', 'self_arg': 'self',
+ 'apply_pos': '(self)', 'apply_kw': '(self)'}
+ wrapper_spec = {
+ 'args': '(self, *args, **kwargs)', 'self_arg': 'self',
+ 'apply_pos': '(self, *args, **kwargs)',
+ 'apply_kw': '(self, *args, **kwargs)'}
+ custom_spec = {
+ 'args': '(slef, a=123)', 'self_arg': 'slef', # yes, slef
+ 'apply_pos': '(slef, a)', 'apply_kw': '(slef, a=a)'}
+
+ self._test_init(None, object_spec, wrapper_spec, custom_spec)
+ self._test_init(True, object_spec, wrapper_spec, custom_spec)
+
+ def test_init_bare(self):
+ object_spec = {
+ 'args': 'self', 'self_arg': 'self',
+ 'apply_pos': 'self', 'apply_kw': 'self'}
+ wrapper_spec = {
+ 'args': 'self, *args, **kwargs', 'self_arg': 'self',
+ 'apply_pos': 'self, *args, **kwargs',
+ 'apply_kw': 'self, *args, **kwargs'}
+ custom_spec = {
+ 'args': 'slef, a=123', 'self_arg': 'slef', # yes, slef
+ 'apply_pos': 'slef, a', 'apply_kw': 'slef, a=a'}
+
+ self._test_init(False, object_spec, wrapper_spec, custom_spec)
+
+ def _test_init(self, grouped, object_spec, wrapper_spec, custom_spec):
+ def test(fn, wanted):
+ if grouped is None:
+ parsed = util.format_argspec_init(fn)
+ else:
+ parsed = util.format_argspec_init(fn, grouped=grouped)
+ eq_(parsed, wanted)
+
+ class O(object): pass
+
+ test(O.__init__, object_spec)
+
+ class O(object):
+ def __init__(self):
+ pass
+
+ test(O.__init__, object_spec)
+
+ class O(object):
+ def __init__(slef, a=123):
+ pass
+
+ test(O.__init__, custom_spec)
+
+ class O(list): pass
+
+ test(O.__init__, wrapper_spec)
+
+ class O(list):
+ def __init__(self, *args, **kwargs):
+ pass
+
+ test(O.__init__, wrapper_spec)
+
+ class O(list):
+ def __init__(self):
+ pass
+
+ test(O.__init__, object_spec)
+
+ class O(list):
+ def __init__(slef, a=123):
+ pass
+
+ test(O.__init__, custom_spec)
+
class AsInterfaceTest(TestBase):
+
class Something(object):
def _ignoreme(self): pass
def foo(self): pass
cls=self.Something, required=('foo'))
obj = self.Something()
- self.assertEqual(obj, util.as_interface(obj, cls=self.Something))
- self.assertEqual(obj, util.as_interface(obj, methods=('foo',)))
- self.assertEqual(
+ eq_(obj, util.as_interface(obj, cls=self.Something))
+ eq_(obj, util.as_interface(obj, methods=('foo',)))
+ eq_(
obj, util.as_interface(obj, cls=self.Something,
required=('outofband',)))
partial = self.Partial()
slotted.bar = lambda self: 123
for obj in partial, slotted:
- self.assertEqual(obj, util.as_interface(obj, cls=self.Something))
+ eq_(obj, util.as_interface(obj, cls=self.Something))
self.assertRaises(TypeError, util.as_interface, obj,
methods=('foo'))
- self.assertEqual(obj, util.as_interface(obj, methods=('bar',)))
- self.assertEqual(
- obj, util.as_interface(obj, cls=self.Something,
+ eq_(obj, util.as_interface(obj, methods=('bar',)))
+ eq_(obj, util.as_interface(obj, cls=self.Something,
required=('bar',)))
self.assertRaises(TypeError, util.as_interface, obj,
cls=self.Something, required=('foo',))
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.databases import firebird
-from sqlalchemy.exceptions import ProgrammingError
+from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql import table, column
from testlib import *
import testenv; testenv.configure_for_tests()
import StringIO, sys
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
from sqlalchemy.util import Decimal
from sqlalchemy.databases import maxdb
from testlib import *
finally:
try:
testing.db.execute("DROP TABLE dectest")
- except exceptions.DatabaseError:
+ except exc.DatabaseError:
pass
def test_decimal_fixed_serial(self):
finally:
try:
testing.db.execute("DROP TABLE assorted")
- except exceptions.DatabaseError:
+ except exc.DatabaseError:
pass
class DBAPITest(TestBase, AssertsExecutionResults):
import re
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from sqlalchemy.sql import table, column
from sqlalchemy.databases import mssql
from testlib import *
r = users.select(limit=3, offset=2,
order_by=[users.c.user_id]).execute().fetchall()
assert False # InvalidRequestError should have been raised
- except exceptions.InvalidRequestError:
+ except exc.InvalidRequestError:
pass
finally:
metadata.drop_all()
import testenv; testenv.configure_for_tests()
import sets
from sqlalchemy import *
-from sqlalchemy import sql, exceptions
+from sqlalchemy import sql, exc
from sqlalchemy.databases import mysql
from testlib import *
try:
enum_table.insert().execute(e1=None, e2=None, e3=None, e4=None)
self.assert_(False)
- except exceptions.SQLError:
+ except exc.SQLError:
self.assert_(True)
try:
enum_table.insert().execute(e1='c', e2='c', e3='c', e4='c')
self.assert_(False)
- except exceptions.InvalidRequestError:
+ except exc.InvalidRequestError:
self.assert_(True)
enum_table.insert().execute()
query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON thirdtable.userid = myothertable.otherid")
- self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
+ self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE thirdtable.userid(+) = myothertable.otherid AND mytable.myid = myothertable.otherid(+)", dialect=oracle.dialect(use_ansi=False))
query = table1.join(table2, table1.c.myid==table2.c.otherid).join(table3, table3.c.userid==table2.c.otherid)
- self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
+ self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE thirdtable.userid = myothertable.otherid AND mytable.myid = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
query = table1.join(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
self.assert_compile(query.select().order_by(table1.oid_column).limit(10).offset(5), "SELECT myid, name, description, otherid, othername, userid, \
mytable.description AS description, myothertable.otherid AS otherid, \
myothertable.othername AS othername, thirdtable.userid AS userid, \
thirdtable.otherstuff AS otherstuff, ROW_NUMBER() OVER (ORDER BY mytable.rowid) AS ora_rn \
-FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid(+) = myothertable.otherid) \
+FROM mytable, myothertable, thirdtable WHERE thirdtable.userid(+) = myothertable.otherid AND mytable.myid = myothertable.otherid) \
WHERE ora_rn>5 AND ora_rn<=15", dialect=oracle.dialect(use_ansi=False))
def test_alias_outer_join(self):
import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from sqlalchemy.databases import postgres
from sqlalchemy.engine.strategies import MockEngineStrategy
from testlib import *
try:
table.insert().execute({'data':'d2'})
assert False
- except exceptions.IntegrityError, e:
+ except exc.IntegrityError, e:
assert "violates not-null constraint" in str(e)
try:
table.insert().execute({'data':'d2'}, {'data':'d3'})
assert False
- except exceptions.IntegrityError, e:
+ except exc.IntegrityError, e:
assert "violates not-null constraint" in str(e)
table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
try:
table.insert().execute({'data':'d2'})
assert False
- except exceptions.IntegrityError, e:
+ except exc.IntegrityError, e:
assert "violates not-null constraint" in str(e)
try:
table.insert().execute({'data':'d2'}, {'data':'d3'})
assert False
- except exceptions.IntegrityError, e:
+ except exc.IntegrityError, e:
assert "violates not-null constraint" in str(e)
table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
try:
con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42')
con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0')
- except exceptions.SQLError, e:
+ except exc.SQLError, e:
if not "already exists" in str(e):
raise e
con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from sqlalchemy.databases import sqlite
from testlib import *
@testing.uses_deprecated('Using String type with no length')
def test_type_reflection(self):
# (ask_for, roundtripped_as_if_different)
- specs = [( String(), sqlite.SLText(), ),
+ specs = [( String(), sqlite.SLString(), ),
( String(1), sqlite.SLString(1), ),
( String(3), sqlite.SLString(3), ),
( Text(), sqlite.SLText(), ),
- ( Unicode(), sqlite.SLText(), ),
+ ( Unicode(), sqlite.SLString(), ),
( Unicode(1), sqlite.SLString(1), ),
( Unicode(3), sqlite.SLString(3), ),
( UnicodeText(), sqlite.SLText(), ),
for table in rt, rv:
for i, reflected in enumerate(table.c):
print reflected.type, type(expected[i])
- assert isinstance(reflected.type, type(expected[i]))
+ assert isinstance(reflected.type, type(expected[i])), type(expected[i])
finally:
db.execute('DROP VIEW types_v')
finally:
except:
try:
cx.execute('DROP TABLE tempy')
- except exceptions.DBAPIError:
+ except exc.DBAPIError:
pass
raise
@testing.exclude('sqlite', '<', (3, 4))
def test_empty_insert_pk2(self):
self.assertRaises(
- exceptions.DBAPIError,
+ exc.DBAPIError,
self._test_empty_insert,
Table('b', MetaData(testing.db),
Column('x', Integer, primary_key=True),
@testing.exclude('sqlite', '<', (3, 4))
def test_empty_insert_pk3(self):
self.assertRaises(
- exceptions.DBAPIError,
+ exc.DBAPIError,
self._test_empty_insert,
Table('c', MetaData(testing.db),
Column('x', Integer, primary_key=True),
including the deprecated versions of these arguments"""
import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import engine, exceptions
-from testlib import *
+from sqlalchemy import engine, exc
+from sqlalchemy import MetaData, ThreadLocalMetaData
+from testlib.sa import Table, Column, Integer, String, func, Sequence, text
+from testlib import TestBase, testing
class BindTest(TestBase):
try:
meth()
assert False
- except exceptions.UnboundExecutionError, e:
+ except exc.UnboundExecutionError, e:
self.assertEquals(
str(e),
"The MetaData "
try:
meth()
assert False
- except exceptions.UnboundExecutionError, e:
+ except exc.UnboundExecutionError, e:
self.assertEquals(
str(e),
"The Table 'test_table' "
@testing.future
def test_create_drop_err2(self):
+ metadata = MetaData()
+ table = Table('test_table', metadata,
+ Column('foo', Integer))
+
for meth in [
table.exists,
table.create,
try:
meth()
assert False
- except exceptions.UnboundExecutionError, e:
+ except exc.UnboundExecutionError, e:
self.assertEquals(
str(e),
"The Table 'test_table' "
assert e.bind is None
e.execute()
assert False
- except exceptions.UnboundExecutionError, e:
+ except exc.UnboundExecutionError, e:
assert str(e).endswith(
'is not bound and does not support direct '
'execution. Supply this statement to a Connection or '
try:
sess.flush()
assert False
- except exceptions.InvalidRequestError, e:
+ except exc.InvalidRequestError, e:
assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
finally:
if isinstance(bind, engine.Connection):
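A minimal sketch of the binding rule BindTest asserts: schema operations on an unbound MetaData or Table raise exc.UnboundExecutionError until a bind is supplied.

    from sqlalchemy import MetaData, Table, Column, Integer, exc

    meta = MetaData()                       # no bind
    t = Table('t', meta, Column('id', Integer))
    try:
        t.create()                          # nothing to execute against
        assert False
    except exc.UnboundExecutionError, e:
        print str(e)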
import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
from sqlalchemy.schema import DDL
-import sqlalchemy
-from testlib import *
+from sqlalchemy import create_engine
+from testlib.sa import MetaData, Table, Column, Integer, String
+import testlib.sa as tsa
+from testlib import TestBase, testing
class DDLEventTest(TestBase):
try:
r = eval(py)
assert False
- except exceptions.UnboundExecutionError:
+ except tsa.exc.UnboundExecutionError:
pass
for bind in engine, cx:
engine = create_engine(testing.db.name + '://',
strategy='mock', executor=executor)
engine.dialect.identifier_preparer = \
- sqlalchemy.sql.compiler.IdentifierPreparer(engine.dialect)
+ tsa.sql.compiler.IdentifierPreparer(engine.dialect)
return engine
def test_tokens(self):
ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
self.assertEquals(ddl._expand(sane_alone, bind), '-t-t')
- self.assertEquals(ddl._expand(sane_schema, bind), '"s"-t-s.t')
+ self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t')
self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
self.assertEquals(ddl._expand(insane_schema, bind),
'"s s"-"t t"-"s s"."t t"')
import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from testlib import *
+import re
+from sqlalchemy.interfaces import ConnectionProxy
+from testlib.sa import MetaData, Table, Column, Integer, String, INT, \
+ VARCHAR, func
+import testlib.sa as tsa
+from testlib import TestBase, testing, engines
+
+users, metadata = None, None
class ExecuteTest(TestBase):
def setUpAll(self):
global users, metadata
try:
conn.execute("osdjafioajwoejoasfjdoifjowejfoawejqoijwef")
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
assert True
+class ProxyConnectionTest(TestBase):
+ def test_proxy(self):
+
+ stmts = []
+ cursor_stmts = []
+
+ class MyProxy(ConnectionProxy):
+ def execute(self, conn, execute, clauseelement, *multiparams, **params):
+ stmts.append(
+ (str(clauseelement), params,multiparams)
+ )
+ return execute(clauseelement, *multiparams, **params)
+
+ def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+ cursor_stmts.append(
+ (statement, parameters, None)
+ )
+ return execute(cursor, statement, parameters, context)
+
+ def assert_stmts(expected, received):
+ for stmt, params, posn in expected:
+            if not received:
+                assert False, "Nothing available to match stmt %r" % stmt
+            while received:
+                teststmt, testparams, testmultiparams = received.pop(0)
+                teststmt = re.compile(r'[\n\t ]+', re.M).sub(' ', teststmt).strip()
+                if teststmt.startswith(stmt) and (testparams==params or testparams==posn):
+                    break
+            else:
+                # while/else: fail if the received list was exhausted
+                # without finding a match for this expected statement
+                assert False, "No received statement matched %r" % stmt
+
+ for engine in (
+ engines.testing_engine(options=dict(proxy=MyProxy())),
+ engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal'))
+ ):
+ m = MetaData(engine)
+
+ t1 = Table('t1', m, Column('c1', Integer, primary_key=True), Column('c2', String(50), default=func.lower('Foo'), primary_key=True))
+
+ m.create_all()
+ try:
+ t1.insert().execute(c1=5, c2='some data')
+ t1.insert().execute(c1=6)
+ assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')]
+ finally:
+ m.drop_all()
+
+ engine.dispose()
+
+ compiled = [
+ ("CREATE TABLE t1", {}, None),
+ ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, None),
+ ("INSERT INTO t1 (c1, c2)", {'c1': 6}, None),
+ ("select * from t1", {}, None),
+ ("DROP TABLE t1", {}, None)
+ ]
+
+ if engine.dialect.preexecute_pk_sequences:
+ cursor = [
+ ("CREATE TABLE t1", {}, None),
+ ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
+ ("SELECT lower", {'lower_2':'Foo'}, ['Foo']),
+ ("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']),
+ ("select * from t1", {}, None),
+ ("DROP TABLE t1", {}, None)
+ ]
+ else:
+ cursor = [
+ ("CREATE TABLE t1", {}, None),
+ ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
+ ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, [6, "Foo"]), # bind param name 'lower_2' might be incorrect
+ ("select * from t1", {}, None),
+ ("DROP TABLE t1", {}, None)
+ ]
+
+ assert_stmts(compiled, stmts)
+ assert_stmts(cursor, cursor_stmts)
+
+
if __name__ == "__main__":
testenv.main()
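ProxyConnectionTest above exercises the new ConnectionProxy hook. A standalone sketch, assuming the proxy= keyword is accepted by create_engine as the testing_engine options suggest (the logging proxy itself is hypothetical; the execute() signature matches the test):

    from sqlalchemy import create_engine
    from sqlalchemy.interfaces import ConnectionProxy

    class LoggingProxy(ConnectionProxy):
        def execute(self, conn, execute, clauseelement, *multiparams, **params):
            print str(clauseelement)          # log, then pass through
            return execute(clauseelement, *multiparams, **params)

    engine = create_engine('sqlite://', proxy=LoggingProxy())
    engine.execute("select 1")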
import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from testlib import *
import pickle
+from sqlalchemy import MetaData
+from testlib.sa import Table, Column, Integer, String, UniqueConstraint, \
+ CheckConstraint, ForeignKey
+import testlib.sa as tsa
+from testlib import TestBase, ComparesTables, testing
+
class MetaDataTest(TestBase, ComparesTables):
def test_metadata_connect(self):
t2 = Table('table1', metadata, Column('col1', Integer, primary_key=True),
Column('col2', String(20)))
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Table 'table1' is already defined for this MetaData instance. Specify 'useexisting=True' to redefine options and columns on an existing Table object."
finally:
metadata.drop_all()
meta.drop_all(testing.db)
def test_nonexistent(self):
- self.assertRaises(exceptions.NoSuchTableError, Table,
+ self.assertRaises(tsa.exc.NoSuchTableError, Table,
'fake_table',
MetaData(testing.db), autoload=True)
import testenv; testenv.configure_for_tests()
import ConfigParser, StringIO
-from sqlalchemy import *
-from sqlalchemy import exceptions, pool, engine
import sqlalchemy.engine.url as url
-from testlib import *
+from sqlalchemy import create_engine, engine_from_config
+import testlib.sa as tsa
+from testlib import TestBase
class ParseConnectTest(TestBase):
}
prefixed = dict(ini.items('prefixed'))
- self.assert_(engine._coerce_config(prefixed, 'sqlalchemy.') == expected)
+ self.assert_(tsa.engine._coerce_config(prefixed, 'sqlalchemy.') == expected)
plain = dict(ini.items('plain'))
- self.assert_(engine._coerce_config(plain, '') == expected)
+ self.assert_(tsa.engine._coerce_config(plain, '') == expected)
def test_engine_from_config(self):
dbapi = MockDBAPI()
try:
c = e.connect()
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
assert True
def test_urlattr(self):
assert e.pool._recycle == 50
# these args work for QueuePool
- e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=pool.QueuePool, module=MockDBAPI())
+ e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI())
try:
# but not SingletonThreadPool
- e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=pool.SingletonThreadPool)
+ e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool)
assert False
except TypeError:
assert True
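A minimal sketch of the config-coercion path these tests cover, using the engine_from_config entry point imported above (URL and echo values hypothetical):

    from sqlalchemy import engine_from_config

    config = {'sqlalchemy.url': 'sqlite://', 'sqlalchemy.echo': 'true'}
    engine = engine_from_config(config, prefix='sqlalchemy.')
    assert str(engine.url) == 'sqlite://'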
import testenv; testenv.configure_for_tests()
-import threading, thread, time, gc
-import sqlalchemy.pool as pool
-import sqlalchemy.interfaces as interfaces
-import sqlalchemy.exceptions as exceptions
-from testlib import *
+import threading, time, gc
+from sqlalchemy import pool
+import testlib.sa as tsa
+from testlib import TestBase
mcid = 1
try:
c4 = p.connect()
assert False
- except exceptions.TimeoutError, e:
+ except tsa.exc.TimeoutError, e:
assert int(time.time() - now) == 2
def test_timeout_race(self):
now = time.time()
try:
c1 = p.connect()
- except exceptions.TimeoutError, e:
+ except tsa.exc.TimeoutError, e:
timeouts.append(int(time.time()) - now)
continue
time.sleep(4)
peaks.append(p.overflow())
con.close()
del con
- except exceptions.TimeoutError:
+ except tsa.exc.TimeoutError:
pass
threads = []
for i in xrange(thread_count):
# con can be None if invalidated
assert record is not None
self.checked_in.append(con)
- class ListenAll(interfaces.PoolListener, InstrumentingListener):
+ class ListenAll(tsa.interfaces.PoolListener, InstrumentingListener):
pass
class ListenConnect(InstrumentingListener):
def connect(self, con, record):
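A compact sketch of the QueuePool timeout contract tested above: with pool_size=1, max_overflow=0, and timeout=2, a second checkout blocks and then raises exc.TimeoutError after roughly two seconds.

    import time, sqlite3
    from sqlalchemy import pool, exc

    p = pool.QueuePool(lambda: sqlite3.connect(':memory:'),
                       pool_size=1, max_overflow=0, timeout=2)
    c1 = p.connect()
    now = time.time()
    try:
        c2 = p.connect()                    # blocks, then times out
        assert False
    except exc.TimeoutError:
        assert int(time.time() - now) == 2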
import testenv; testenv.configure_for_tests()
-import sys, weakref
-from sqlalchemy import create_engine, exceptions, select, MetaData, Table, Column, Integer, String
-from testlib import *
+import weakref
+from testlib.sa import select, MetaData, Table, Column, Integer, String
+import testlib.sa as tsa
+from testlib import TestBase, testing, engines
class MockDisconnect(Exception):
def close(self):
pass
+db, dbapi = None, None
class MockReconnectTest(TestBase):
def setUp(self):
global db, dbapi
dbapi = MockDBAPI()
# create engine using our current dburi
- db = create_engine('postgres://foo:bar@localhost/test', module=dbapi)
+ db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi)
# monkeypatch disconnect checker
db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
pass
# assert was invalidated
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
pass
# assert was invalidated
try:
conn.execute(select([1]))
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Can't reconnect until invalid transaction is rolled back"
assert trans.is_active
try:
trans.commit()
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Can't reconnect until invalid transaction is rolled back"
assert trans.is_active
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError:
+ except tsa.exc.DBAPIError:
pass
assert not conn.closed
assert not conn.invalidated
assert len(dbapi.connections) == 1
-
+engine = None
class RealReconnectTest(TestBase):
def setUp(self):
global engine
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
assert conn.invalidated
assert not conn.invalidated
conn.close()
-
+
def test_close(self):
conn = engine.connect()
self.assertEquals(conn.execute(select([1])).scalar(), 1)
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
try:
conn.execute(select([1]))
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
try:
conn.execute(select([1]))
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Can't reconnect until invalid transaction is rolled back"
assert trans.is_active
try:
trans.commit()
assert False
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
assert str(e) == "Can't reconnect until invalid transaction is rolled back"
assert trans.is_active
self.assertEquals(conn.execute(select([1])).scalar(), 1)
assert not conn.invalidated
+meta, table, engine = None, None, None
class InvalidateDuringResultTest(TestBase):
def setUp(self):
global meta, table, engine
table.insert().execute(
[{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
)
-
+
def tearDown(self):
meta.drop_all()
engine.dispose()
-
- @testing.fails_on('mysql')
+
+ @testing.fails_on('mysql')
def test_invalidate_on_results(self):
conn = engine.connect()
-
+
result = conn.execute("select * from sometable")
for x in xrange(20):
result.fetchone()
-
+
engine.test_shutdown()
try:
result.fetchone()
assert False
- except exceptions.DBAPIError, e:
+ except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
raise
assert conn.invalidated
-
+
if __name__ == '__main__':
testenv.main()
import testenv; testenv.configure_for_tests()
import StringIO, unicodedata
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from sqlalchemy import types as sqltypes
-from testlib import *
-from testlib import engines
+import sqlalchemy as sa
+from testlib.sa import MetaData, Table, Column
+from testlib import TestBase, ComparesTables, testing, engines, sa as tsa
+from testlib.compat import set
+metadata, users = None, None
+
class ReflectionTest(TestBase, ComparesTables):
@testing.exclude('mysql', '<', (4, 1, 1))
meta = MetaData(testing.db)
users = Table('engine_users', meta,
- Column('user_id', INT, primary_key=True),
- Column('user_name', VARCHAR(20), nullable=False),
- Column('test1', CHAR(5), nullable=False),
- Column('test2', Float(5), nullable=False),
- Column('test3', Text),
- Column('test4', Numeric, nullable = False),
- Column('test5', DateTime),
- Column('parent_user_id', Integer, ForeignKey('engine_users.user_id')),
- Column('test6', DateTime, nullable=False),
- Column('test7', Text),
- Column('test8', Binary),
- Column('test_passivedefault2', Integer, PassiveDefault("5")),
- Column('test9', Binary(100)),
- Column('test_numeric', Numeric()),
+ Column('user_id', sa.INT, primary_key=True),
+ Column('user_name', sa.VARCHAR(20), nullable=False),
+ Column('test1', sa.CHAR(5), nullable=False),
+ Column('test2', sa.Float(5), nullable=False),
+ Column('test3', sa.Text),
+ Column('test4', sa.Numeric, nullable = False),
+ Column('test5', sa.DateTime),
+ Column('parent_user_id', sa.Integer,
+ sa.ForeignKey('engine_users.user_id')),
+ Column('test6', sa.DateTime, nullable=False),
+ Column('test7', sa.Text),
+ Column('test8', sa.Binary),
+ Column('test_passivedefault2', sa.Integer, sa.PassiveDefault("5")),
+ Column('test9', sa.Binary(100)),
+ Column('test_numeric', sa.Numeric()),
test_needs_fk=True,
)
addresses = Table('engine_email_addresses', meta,
- Column('address_id', Integer, primary_key = True),
- Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
- Column('email_address', String(20)),
+ Column('address_id', sa.Integer, primary_key = True),
+ Column('remote_user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+ Column('email_address', sa.String(20)),
test_needs_fk=True,
)
meta.create_all()
try:
meta2 = MetaData()
- reflected_users = Table('engine_users', meta2, autoload=True, autoload_with=testing.db)
- reflected_addresses = Table('engine_email_addresses', meta2, autoload=True, autoload_with=testing.db)
+ reflected_users = Table('engine_users', meta2, autoload=True,
+ autoload_with=testing.db)
+ reflected_addresses = Table('engine_email_addresses', meta2,
+ autoload=True, autoload_with=testing.db)
self.assert_tables_equal(users, reflected_users)
self.assert_tables_equal(addresses, reflected_addresses)
finally:
def test_include_columns(self):
meta = MetaData(testing.db)
- foo = Table('foo', meta, *[Column(n, String(30)) for n in ['a', 'b', 'c', 'd', 'e', 'f']])
+ foo = Table('foo', meta, *[Column(n, sa.String(30))
+ for n in ['a', 'b', 'c', 'd', 'e', 'f']])
meta.create_all()
try:
meta2 = MetaData(testing.db)
- foo = Table('foo', meta2, autoload=True, include_columns=['b', 'f', 'e'])
+ foo = Table('foo', meta2, autoload=True,
+ include_columns=['b', 'f', 'e'])
# test that cols come back in original order
self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
for c in ('b', 'f', 'e'):
assert c in foo.c
for c in ('a', 'c', 'd'):
assert c not in foo.c
-
+
# test against a table which is already reflected
meta3 = MetaData(testing.db)
foo = Table('foo', meta3, autoload=True)
- foo = Table('foo', meta3, include_columns=['b', 'f', 'e'], useexisting=True)
+ foo = Table('foo', meta3, include_columns=['b', 'f', 'e'],
+ useexisting=True)
self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
for c in ('b', 'f', 'e'):
assert c in foo.c
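A minimal sketch of the include_columns reflection asserted above, reusing this module's testing.db bind and the 'foo' table created by the test: only the listed columns are reflected, in table-definition order.

    meta2 = MetaData(testing.db)
    partial = Table('foo', meta2, autoload=True, include_columns=['b', 'e'])
    assert [c.name for c in partial.c] == ['b', 'e']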
def test_unknown_types(self):
meta = MetaData(testing.db)
t = Table("test", meta,
- Column('foo', DateTime))
+ Column('foo', sa.DateTime))
import sys
dialect_module = sys.modules[testing.db.dialect.__module__]
m2 = MetaData(testing.db)
t2 = Table("test", m2, autoload=True)
assert False
- except exceptions.SAWarning:
+ except tsa.exc.SAWarning:
assert True
@testing.emits_warning('Did not recognize type')
def warns():
m3 = MetaData(testing.db)
t3 = Table("test", m3, autoload=True)
- assert t3.c.foo.type.__class__ == sqltypes.NullType
+ assert t3.c.foo.type.__class__ == sa.types.NullType
finally:
dialect_module.ischema_names = ischema_names
meta = MetaData(testing.db)
table = Table(
'override_test', meta,
- Column('col1', Integer, primary_key=True),
- Column('col2', String(20)),
- Column('col3', Numeric)
+ Column('col1', sa.Integer, primary_key=True),
+ Column('col2', sa.String(20)),
+ Column('col3', sa.Numeric)
)
table.create()
try:
table = Table(
'override_test', meta2,
- Column('col2', Unicode()),
- Column('col4', String(30)), autoload=True)
+ Column('col2', sa.Unicode()),
+ Column('col4', sa.String(30)), autoload=True)
- self.assert_(isinstance(table.c.col1.type, Integer))
- self.assert_(isinstance(table.c.col2.type, Unicode))
- self.assert_(isinstance(table.c.col4.type, String))
+ self.assert_(isinstance(table.c.col1.type, sa.Integer))
+ self.assert_(isinstance(table.c.col2.type, sa.Unicode))
+ self.assert_(isinstance(table.c.col4.type, sa.String))
finally:
table.drop()
meta = MetaData(testing.db)
users = Table('users', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(30)))
+ Column('id', sa.Integer, primary_key=True),
+ Column('name', sa.String(30)))
addresses = Table('addresses', meta,
- Column('id', Integer, primary_key=True),
- Column('street', String(30)))
+ Column('id', sa.Integer, primary_key=True),
+ Column('street', sa.String(30)))
meta.create_all()
try:
meta2 = MetaData(testing.db)
a2 = Table('addresses', meta2,
- Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+ Column('id', sa.Integer,
+ sa.ForeignKey('users.id'), primary_key=True),
autoload=True)
u2 = Table('users', meta2, autoload=True)
meta3 = MetaData(testing.db)
u3 = Table('users', meta3, autoload=True)
a3 = Table('addresses', meta3,
- Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+ Column('id', sa.Integer, sa.ForeignKey('users.id'),
+ primary_key=True),
autoload=True)
assert list(a3.primary_key) == [a3.c.id]
meta = MetaData(testing.db)
users = Table('users', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(30)))
+ Column('id', sa.Integer, primary_key=True),
+ Column('name', sa.String(30)))
addresses = Table('addresses', meta,
- Column('id', Integer, primary_key=True),
- Column('street', String(30)),
- Column('user_id', Integer))
+ Column('id', sa.Integer, primary_key=True),
+ Column('street', sa.String(30)),
+ Column('user_id', sa.Integer))
meta.create_all()
try:
meta2 = MetaData(testing.db)
a2 = Table('addresses', meta2,
- Column('user_id', Integer, ForeignKey('users.id')),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
autoload=True)
u2 = Table('users', meta2, autoload=True)
meta3 = MetaData(testing.db)
u3 = Table('users', meta3, autoload=True)
a3 = Table('addresses', meta3,
- Column('user_id', Integer, ForeignKey('users.id')),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
autoload=True)
assert u3.join(a3).onclause == u3.c.id==a3.c.user_id
meta4 = MetaData(testing.db)
u4 = Table('users', meta4,
- Column('id', Integer, key='u_id', primary_key=True),
+ Column('id', sa.Integer, key='u_id', primary_key=True),
autoload=True)
a4 = Table('addresses', meta4,
- Column('id', Integer, key='street', primary_key=True),
- Column('street', String(30), key='user_id'),
- Column('user_id', Integer, ForeignKey('users.u_id'),
+ Column('id', sa.Integer, key='street', primary_key=True),
+ Column('street', sa.String(30), key='user_id'),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.u_id'),
key='id'),
autoload=True)
meta = MetaData(testing.db)
users = Table('users', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(30)),
+ Column('id', sa.Integer, primary_key=True),
+ Column('name', sa.String(30)),
test_needs_fk=True)
addresses = Table('addresses', meta,
- Column('id', Integer,primary_key=True),
- Column('user_id', Integer, ForeignKey('users.id')),
+ Column('id', sa.Integer, primary_key=True),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
test_needs_fk=True)
meta.create_all()
try:
meta2 = MetaData(testing.db)
a2 = Table('addresses', meta2,
- Column('user_id',Integer, ForeignKey('users.id')),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
autoload=True)
u2 = Table('users', meta2, autoload=True)
meta2 = MetaData(testing.db)
u2 = Table('users', meta2,
- Column('id', Integer, primary_key=True),
+ Column('id', sa.Integer, primary_key=True),
autoload=True)
a2 = Table('addresses', meta2,
- Column('id', Integer, primary_key=True),
- Column('user_id',Integer, ForeignKey('users.id')),
+ Column('id', sa.Integer, primary_key=True),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
autoload=True)
assert len(a2.foreign_keys) == 1
assert u2.join(a2).onclause == u2.c.id==a2.c.user_id
finally:
meta.drop_all()
-
+
def test_use_existing(self):
meta = MetaData(testing.db)
users = Table('users', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(30)),
+ Column('id', sa.Integer, primary_key=True),
+ Column('name', sa.String(30)),
test_needs_fk=True)
addresses = Table('addresses', meta,
- Column('id', Integer,primary_key=True),
- Column('user_id', Integer, ForeignKey('users.id')),
- Column('data', String(100)),
+ Column('id', sa.Integer,primary_key=True),
+ Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
+ Column('data', sa.String(100)),
test_needs_fk=True)
meta.create_all()
try:
meta2 = MetaData(testing.db)
- addresses = Table('addresses', meta2, Column('data', Unicode), autoload=True)
+ addresses = Table('addresses', meta2, Column('data', sa.Unicode), autoload=True)
try:
- users = Table('users', meta2, Column('name', Unicode), autoload=True)
+ users = Table('users', meta2, Column('name', sa.Unicode), autoload=True)
assert False
- except exceptions.InvalidRequestError, err:
+ except tsa.exc.InvalidRequestError, err:
assert str(err) == "Table 'users' is already defined for this MetaData instance. Specify 'useexisting=True' to redefine options and columns on an existing Table object."
- users = Table('users', meta2, Column('name', Unicode), autoload=True, useexisting=True)
- assert isinstance(users.c.name.type, Unicode)
+ users = Table('users', meta2, Column('name', sa.Unicode), autoload=True, useexisting=True)
+ assert isinstance(users.c.name.type, sa.Unicode)
assert not users.quote
try:
metadata = MetaData(bind=testing.db)
book = Table('book', metadata, autoload=True)
- assert book.c.id in book.primary_key
- assert book.c.series not in book.primary_key
+ assert book.primary_key.contains_column(book.c.id)
+ assert not book.primary_key.contains_column(book.c.series)
assert len(book.primary_key) == 1
finally:
testing.db.execute("drop table book")
def test_fk_error(self):
metadata = MetaData(testing.db)
slots_table = Table('slots', metadata,
- Column('slot_id', Integer, primary_key=True),
- Column('pkg_id', Integer, ForeignKey('pkgs.pkg_id')),
- Column('slot', String(128)),
+ Column('slot_id', sa.Integer, primary_key=True),
+ Column('pkg_id', sa.Integer, sa.ForeignKey('pkgs.pkg_id')),
+ Column('slot', sa.String(128)),
)
try:
metadata.create_all()
assert False
- except exceptions.InvalidRequestError, err:
+ except tsa.exc.InvalidRequestError, err:
assert str(err) == "Could not find table 'pkgs' with which to generate a foreign key"
def test_composite_pks(self):
try:
metadata = MetaData(bind=testing.db)
book = Table('book', metadata, autoload=True)
- assert book.c.id in book.primary_key
- assert book.c.isbn in book.primary_key
- assert book.c.series not in book.primary_key
+ assert book.primary_key.contains_column(book.c.id)
+ assert book.primary_key.contains_column(book.c.isbn)
+ assert not book.primary_key.contains_column(book.c.series)
assert len(book.primary_key) == 2
finally:
testing.db.execute("drop table book")
meta = MetaData(testing.db)
multi = Table(
'multi', meta,
- Column('multi_id', Integer, primary_key=True),
- Column('multi_rev', Integer, primary_key=True),
- Column('multi_hoho', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('val', String(100)),
+ Column('multi_id', sa.Integer, primary_key=True),
+ Column('multi_rev', sa.Integer, primary_key=True),
+ Column('multi_hoho', sa.Integer, primary_key=True),
+ Column('name', sa.String(50), nullable=False),
+ Column('val', sa.String(100)),
test_needs_fk=True,
)
multi2 = Table('multi2', meta,
- Column('id', Integer, primary_key=True),
- Column('foo', Integer),
- Column('bar', Integer),
- Column('lala', Integer),
- Column('data', String(50)),
- ForeignKeyConstraint(['foo', 'bar', 'lala'], ['multi.multi_id', 'multi.multi_rev', 'multi.multi_hoho']),
+ Column('id', sa.Integer, primary_key=True),
+ Column('foo', sa.Integer),
+ Column('bar', sa.Integer),
+ Column('lala', sa.Integer),
+ Column('data', sa.String(50)),
+ sa.ForeignKeyConstraint(['foo', 'bar', 'lala'], ['multi.multi_id', 'multi.multi_rev', 'multi.multi_hoho']),
test_needs_fk=True,
)
meta.create_all()
table2 = Table('multi2', meta2, autoload=True, autoload_with=testing.db)
self.assert_tables_equal(multi, table)
self.assert_tables_equal(multi2, table2)
- j = join(table, table2)
- self.assert_(and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar, table.c.multi_hoho==table2.c.lala).compare(j.onclause))
+ j = sa.join(table, table2)
+ self.assert_(sa.and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar, table.c.multi_hoho==table2.c.lala).compare(j.onclause))
finally:
meta.drop_all()
# check a table that uses an SQL reserved name doesn't cause an error
meta = MetaData(testing.db)
table_a = Table('select', meta,
- Column('not', Integer, primary_key=True),
- Column('from', String(12), nullable=False),
- UniqueConstraint('from', name='when'))
- Index('where', table_a.c['from'])
+ Column('not', sa.Integer, primary_key=True),
+ Column('from', sa.String(12), nullable=False),
+ sa.UniqueConstraint('from', name='when'))
+ sa.Index('where', table_a.c['from'])
# There's currently no way to calculate identifier case normalization
# in isolation, so...
quoter = meta.bind.dialect.identifier_preparer.quote_identifier
table_b = Table('false', meta,
- Column('create', Integer, primary_key=True),
- Column('true', Integer, ForeignKey('select.not')),
- CheckConstraint('%s <> 1' % quoter(check_col),
+ Column('create', sa.Integer, primary_key=True),
+ Column('true', sa.Integer, sa.ForeignKey('select.not')),
+ sa.CheckConstraint('%s <> 1' % quoter(check_col),
name='limit'))
table_c = Table('is', meta,
- Column('or', Integer, nullable=False, primary_key=True),
- Column('join', Integer, nullable=False, primary_key=True),
- PrimaryKeyConstraint('or', 'join', name='to'))
+ Column('or', sa.Integer, nullable=False, primary_key=True),
+ Column('join', sa.Integer, nullable=False, primary_key=True),
+ sa.PrimaryKeyConstraint('or', 'join', name='to'))
- index_c = Index('else', table_c.c.join)
+ index_c = sa.Index('else', table_c.c.join)
meta.create_all()
baseline = MetaData(testing.db)
for name in names:
- Table(name, baseline, Column('id', Integer, primary_key=True))
+ Table(name, baseline, Column('id', sa.Integer, primary_key=True))
baseline.create_all()
try:
try:
m4.reflect(only=['rt_a', 'rt_f'])
self.assert_(False)
- except exceptions.InvalidRequestError, e:
+ except tsa.exc.InvalidRequestError, e:
self.assert_(e.args[0].endswith('(rt_f)'))
m5 = MetaData(testing.db)
try:
m8 = MetaData(reflect=True)
self.assert_(False)
- except exceptions.ArgumentError, e:
+ except tsa.exc.ArgumentError, e:
self.assert_(
e.args[0] ==
"A bind must be supplied in conjunction with reflect=True")
global metadata, users
metadata = MetaData()
users = Table('users', metadata,
- Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True),
- Column('user_name', String(40)),
+ Column('user_id', sa.Integer, sa.Sequence('user_id_seq', optional=True), primary_key=True),
+ Column('user_name', sa.String(40)),
)
addresses = Table('email_addresses', metadata,
- Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
- Column('user_id', Integer, ForeignKey(users.c.user_id)),
- Column('email_address', String(40)),
+ Column('address_id', sa.Integer, sa.Sequence('address_id_seq', optional=True), primary_key = True),
+ Column('user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+ Column('email_address', sa.String(40)),
)
orders = Table('orders', metadata,
- Column('order_id', Integer, Sequence('order_id_seq', optional=True), primary_key = True),
- Column('user_id', Integer, ForeignKey(users.c.user_id)),
- Column('description', String(50)),
- Column('isopen', Integer),
+ Column('order_id', sa.Integer, sa.Sequence('order_id_seq', optional=True), primary_key = True),
+ Column('user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+ Column('description', sa.String(50)),
+ Column('isopen', sa.Integer),
)
orderitems = Table('items', metadata,
- Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
- Column('order_id', INT, ForeignKey("orders")),
- Column('item_name', VARCHAR(50)),
+ Column('item_id', sa.INT, sa.Sequence('items_id_seq', optional=True), primary_key = True),
+ Column('order_id', sa.INT, sa.ForeignKey("orders")),
+ Column('item_name', sa.VARCHAR(50)),
)
def test_sorter( self ):
def test_append_constraint_unique(self):
meta = MetaData()
- users = Table('users', meta, Column('id', Integer))
- addresses = Table('addresses', meta, Column('id', Integer), Column('user_id', Integer))
+ users = Table('users', meta, Column('id', sa.Integer))
+ addresses = Table('addresses', meta, Column('id', sa.Integer), Column('user_id', sa.Integer))
- fk = ForeignKeyConstraint(['user_id'],[users.c.id])
+ fk = sa.ForeignKeyConstraint(['user_id'],[users.c.id])
addresses.append_constraint(fk)
addresses.append_constraint(fk)
names = set([u'plain', u'Unit\u00e9ble', u'\u6e2c\u8a66'])
for name in names:
- Table(name, metadata, Column('id', Integer, Sequence(name + "_id_seq"), primary_key=True))
+ Table(name, metadata, Column('id', sa.Integer, sa.Sequence(name + "_id_seq"), primary_key=True))
metadata.create_all()
reflected = set(bind.table_names())
def test_iteration(self):
metadata = MetaData()
table1 = Table('table1', metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', sa.Integer, primary_key=True),
schema='someschema')
table2 = Table('table2', metadata,
- Column('col1', Integer, primary_key=True),
- Column('col2', Integer, ForeignKey('someschema.table1.col1')),
+ Column('col1', sa.Integer, primary_key=True),
+ Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')),
schema='someschema')
        # ensure this doesn't crash
print [t for t in metadata.table_iterator()]
buf = StringIO.StringIO()
def foo(s, p=None):
buf.write(s)
- gen = create_engine(testing.db.name + "://", strategy="mock", executor=foo)
+ gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo)
gen = gen.dialect.schemagenerator(gen.dialect, gen)
gen.traverse(table1)
gen.traverse(table2)
metadata = MetaData(engine)
table1 = Table('table1', metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', sa.Integer, primary_key=True),
schema=schema)
table2 = Table('table2', metadata,
- Column('col1', Integer, primary_key=True),
- Column('col2', Integer,
- ForeignKey('%s.table1.col1' % schema)),
+ Column('col1', sa.Integer, primary_key=True),
+ Column('col2', sa.Integer,
+ sa.ForeignKey('%s.table1.col1' % schema)),
schema=schema)
try:
metadata.create_all()
global metadata, users
metadata = MetaData()
users = Table('users', metadata,
- Column('user_id', Integer, Sequence('user_id_seq'), primary_key=True),
- Column('user_name', String(40)),
+ Column('user_id', sa.Integer, sa.Sequence('user_id_seq'), primary_key=True),
+ Column('user_name', sa.String(40)),
)
@testing.unsupported('sqlite', 'mysql', 'mssql', 'access', 'sybase')
import testenv; testenv.configure_for_tests()
import sys, time, threading
-
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from testlib import *
+from testlib.sa import create_engine, MetaData, Table, Column, INT, VARCHAR, \
+ Sequence, select, Integer, String, func, text
+from testlib import TestBase, testing
+users, metadata = None, None
class TransactionTest(TestBase):
def setUpAll(self):
global users, metadata
def tearDownAll(self):
users.drop(testing.db)
- def testcommits(self):
+ def test_commits(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
assert len(result.fetchall()) == 3
transaction.commit()
- def testrollback(self):
+ def test_rollback(self):
"""test a basic rollback"""
connection = testing.db.connect()
transaction = connection.begin()
assert len(result.fetchall()) == 0
connection.close()
- def testraise(self):
+ def test_raise(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testnestedrollback(self):
+ def test_nested_rollback(self):
connection = testing.db.connect()
try:
@testing.exclude('mysql', '<', (5, 0, 3))
- def testnesting(self):
+ def test_nesting(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testclose(self):
+ def test_close(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testclose2(self):
+ def test_close2(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
assert len(result.fetchall()) == 0
connection.close()
-
- @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testnestedsubtransactionrollback(self):
+ @testing.requires.savepoints
+ def test_nested_subtransaction_rollback(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
)
connection.close()
- @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testnestedsubtransactioncommit(self):
+ @testing.requires.savepoints
+ def test_nested_subtransaction_commit(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
)
connection.close()
- @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testrollbacktosubtransaction(self):
+ @testing.requires.savepoints
+ def test_rollback_to_subtransaction(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name='user1')
)
connection.close()
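The three tests above now key off testing.requires.savepoints. A hypothetical sketch of the SAVEPOINT round trip they cover, assuming Connection.begin_nested() as in the 0.5 transaction API and this module's users table:

    conn = testing.db.connect()
    trans = conn.begin()
    conn.execute(users.insert(), user_id=1, user_name='user1')
    nested = conn.begin_nested()            # emits SAVEPOINT
    conn.execute(users.insert(), user_id=2, user_name='user2')
    nested.rollback()                       # rolls back to the savepoint only
    trans.commit()                          # user1 persists, user2 does not
    conn.close()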
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testtwophasetransaction(self):
+ @testing.requires.two_phase_transactions
+ def test_two_phase_transaction(self):
connection = testing.db.connect()
transaction = connection.begin_twophase()
)
connection.close()
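A condensed sketch of the two-phase protocol exercised here, on backends admitted by testing.requires.two_phase_transactions:

    conn = testing.db.connect()
    xa = conn.begin_twophase()
    conn.execute(users.insert(), user_id=1, user_name='user1')
    xa.prepare()     # phase one: transaction is made durable
    xa.commit()      # phase two: commit the prepared transaction
    conn.close()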
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testmixedtwophasetransaction(self):
+ @testing.requires.two_phase_transactions
+ @testing.requires.savepoints
+ def test_mixed_two_phase_transaction(self):
connection = testing.db.connect()
transaction = connection.begin_twophase()
)
connection.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- # fixme: see if this is still true and/or can be convert to fails_on()
- @testing.unsupported('mysql')
- def testtwophaserecover(self):
+ @testing.requires.two_phase_transactions
+ @testing.fails_on('mysql')
+ def test_two_phase_recover(self):
# MySQL recovery doesn't currently seem to work correctly
# Prepared transactions disappear when connections are closed and even
# when they aren't it doesn't seem possible to use the recovery id.
)
connection2.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testmultipletwophase(self):
+ @testing.requires.two_phase_transactions
+ def test_multiple_two_phase(self):
conn = testing.db.connect()
xa = conn.begin_twophase()
metadata.drop_all(testing.db)
@testing.unsupported('sqlite')
- def testrollback_deadlock(self):
+ def test_rollback_deadlock(self):
"""test that returning connections to the pool clears any object locks."""
conn1 = testing.db.connect()
conn2 = testing.db.connect()
users.drop(conn2)
conn2.close()
+foo = None
class ExplicitAutoCommitTest(TestBase):
- """test the 'autocommit' flag on select() and text() objects.
-
+ """test the 'autocommit' flag on select() and text() objects.
+
Requires Postgres so that we may define a custom function which modifies the database.
"""
-
+
__only_on__ = 'postgres'
def setUpAll(self):
def tearDown(self):
foo.delete().execute()
-
+
def tearDownAll(self):
testing.db.execute("drop function insert_foo(varchar)")
metadata.drop_all()
-
+
def test_control(self):
- # test that not using autocommit does not commit
+ # test that not using autocommit does not commit
conn1 = testing.db.connect()
conn2 = testing.db.connect()
trans.commit()
assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',), ('moredata',)]
-
+
conn1.close()
conn2.close()
-
+
def test_explicit_compiled(self):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
-
+
conn1.execute(select([func.insert_foo('data1')], autocommit=True))
assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',)]
conn1.execute(select([func.insert_foo('data2')]).autocommit())
assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',), ('data2',)]
-
+
conn1.close()
conn2.close()
-
+
def test_explicit_text(self):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
-
+
conn1.execute(text("select insert_foo('moredata')", autocommit=True))
assert conn2.execute(select([foo.c.data])).fetchall() == [('moredata',)]
-
+
conn1.close()
conn2.close()
def test_implicit_text(self):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
-
+
conn1.execute(text("insert into foo (data) values ('implicitdata')"))
assert conn2.execute(select([foo.c.data])).fetchall() == [('implicitdata',)]
-
+
conn1.close()
conn2.close()
-
-
+
+
+tlengine = None
class TLTransactionTest(TestBase):
def setUpAll(self):
global users, metadata, tlengine
finally:
external_connection.close()
- def testrollback(self):
+ def test_rollback(self):
"""test a basic rollback"""
tlengine.begin()
tlengine.execute(users.insert(), user_id=1, user_name='user1')
finally:
external_connection.close()
- def testcommit(self):
+ def test_commit(self):
"""test a basic commit"""
tlengine.begin()
tlengine.execute(users.insert(), user_id=1, user_name='user1')
finally:
external_connection.close()
- def testcommits(self):
+ def test_commits(self):
assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0
connection = tlengine.contextual_connect()
assert len(l) == 3, "expected 3 got %d" % len(l)
transaction.commit()
- def testrollback_off_conn(self):
+ def test_rollback_off_conn(self):
# test that a TLTransaction opened off a TLConnection allows that
# TLConnection to be aware of the transactional context
conn = tlengine.contextual_connect()
finally:
external_connection.close()
- def testmorerollback_off_conn(self):
+ def test_morerollback_off_conn(self):
        # test that an existing TLConnection automatically takes part in a TLTransaction
# opened on a second TLConnection
conn = tlengine.contextual_connect()
finally:
external_connection.close()
- def testcommit_off_conn(self):
+ def test_commit_off_connection(self):
conn = tlengine.contextual_connect()
trans = conn.begin()
conn.execute(users.insert(), user_id=1, user_name='user1')
@testing.unsupported('sqlite')
@testing.exclude('mysql', '<', (5, 0, 3))
- def testnesting(self):
+ def test_nesting(self):
"""tests nesting of transactions"""
external_connection = tlengine.connect()
self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
external_connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testmixednesting(self):
+ def test_mixed_nesting(self):
"""tests nesting of transactions off the TLEngine directly inside of
tranasctions off the connection from the TLEngine"""
external_connection = tlengine.connect()
external_connection.close()
@testing.exclude('mysql', '<', (5, 0, 3))
- def testmoremixednesting(self):
+ def test_more_mixed_nesting(self):
"""tests nesting of transactions off the connection from the TLEngine
inside of tranasctions off thbe TLEngine directly."""
external_connection = tlengine.connect()
finally:
external_connection.close()
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testsessionnesting(self):
- class User(object):
- pass
- try:
- mapper(User, users)
-
- sess = create_session(bind=tlengine)
- tlengine.begin()
- u = User()
- sess.save(u)
- sess.flush()
- tlengine.commit()
- finally:
- clear_mappers()
- def testconnections(self):
+ def test_connections(self):
"""tests that contextual_connect is threadlocal"""
c1 = tlengine.contextual_connect()
c2 = tlengine.contextual_connect()
c2.close()
assert c1.connection.connection is not None
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
- def testtwophasetransaction(self):
+ @testing.requires.two_phase_transactions
+ def test_two_phase_transaction(self):
tlengine.begin_twophase()
tlengine.execute(users.insert(), user_id=1, user_name='user1')
tlengine.prepare()
[(1,),(2,)]
)
+counters = None
class ForUpdateTest(TestBase):
def setUpAll(self):
global counters, metadata
@testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access')
- def testqueued_update(self):
+ def test_queued_update(self):
"""Test SELECT FOR UPDATE with concurrent modifications.
Runs concurrent modifications on a single row in the users table,
return errors
@testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access')
- def testqueued_select(self):
+ def test_queued_select(self):
"""Simple SELECT FOR UPDATE conflict test"""
errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)])
@testing.unsupported('sqlite', 'mysql', 'mssql', 'firebird',
'sybase', 'access')
- def testnowait_select(self):
+ def test_nowait_select(self):
"""Simple SELECT FOR UPDATE NOWAIT conflict test"""
errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)],
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from datetime import datetime
-
-from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, many_to_many, objectstore
-from sqlalchemy import and_, or_, exceptions
-from sqlalchemy import ForeignKey, String, Integer, DateTime, Table, Column
-from sqlalchemy.orm import clear_mappers, backref, create_session, class_mapper
-import sqlalchemy.ext.activemapper as activemapper
-import sqlalchemy
-from testlib import *
-
-
-class testcase(TestBase):
- def setUpAll(self):
- clear_mappers()
- objectstore.clear()
- global Person, Preferences, Address
-
- class Person(ActiveMapper):
- class mapping:
- __version_id_col__ = 'row_version'
- full_name = column(String(128))
- first_name = column(String(128))
- middle_name = column(String(128))
- last_name = column(String(128))
- birth_date = column(DateTime)
- ssn = column(String(128))
- gender = column(String(128))
- home_phone = column(String(128))
- cell_phone = column(String(128))
- work_phone = column(String(128))
- row_version = column(Integer, default=0)
- prefs_id = column(Integer, foreign_key=ForeignKey('preferences.id'))
- addresses = one_to_many('Address', colname='person_id', backref='person', order_by=['state', 'city', 'postal_code'])
- preferences = one_to_one('Preferences', colname='pref_id', backref='person')
-
- def __str__(self):
- s = '%s\n' % self.full_name
- s += ' * birthdate: %s\n' % (self.birth_date or 'not provided')
- s += ' * fave color: %s\n' % (self.preferences.favorite_color or 'Unknown')
- s += ' * personality: %s\n' % (self.preferences.personality_type or 'Unknown')
-
- for address in self.addresses:
- s += ' * address: %s\n' % address.address_1
- s += ' %s, %s %s\n' % (address.city, address.state, address.postal_code)
-
- return s
-
- class Preferences(ActiveMapper):
- class mapping:
- __table__ = 'preferences'
- favorite_color = column(String(128))
- personality_type = column(String(128))
-
- class Address(ActiveMapper):
- class mapping:
- # note that in other objects, the 'id' primary key is
- # automatically added -- if you specify a primary key,
- # then ActiveMapper will not add an integer primary key
- # for you.
- id = column(Integer, primary_key=True)
- type = column(String(128))
- address_1 = column(String(128))
- city = column(String(128))
- state = column(String(128))
- postal_code = column(String(128))
- person_id = column(Integer, foreign_key=ForeignKey('person.id'))
-
- activemapper.metadata.bind = testing.db
- activemapper.create_tables()
-
- def tearDownAll(self):
- clear_mappers()
- activemapper.drop_tables()
-
- def tearDown(self):
- for t in activemapper.metadata.table_iterator(reverse=True):
- t.delete().execute()
-
- def create_person_one(self):
- # create a person
- p1 = Person(
- full_name='Jonathan LaCour',
- birth_date=datetime(1979, 10, 12),
- preferences=Preferences(
- favorite_color='Green',
- personality_type='ENTP'
- ),
- addresses=[
- Address(
- address_1='123 Some Great Road.',
- city='Atlanta',
- state='GA',
- postal_code='30338'
- ),
- Address(
- address_1='435 Franklin Road.',
- city='Atlanta',
- state='GA',
- postal_code='30342'
- )
- ]
- )
- return p1
-
-
- def create_person_two(self):
- p2 = Person(
- full_name='Lacey LaCour',
- addresses=[
- Address(
- address_1='123 Some Great Road.',
- city='Atlanta',
- state='GA',
- postal_code='30338'
- ),
- Address(
- address_1='200 Main Street',
- city='Roswell',
- state='GA',
- postal_code='30075'
- )
- ]
- )
- # I don't like that I have to do this... and putting
- # a "self.preferences = Preferences()" into the __init__
- # of Person also doens't seem to fix this
- p2.preferences = Preferences()
-
- return p2
-
-
- def test_create(self):
- p1 = self.create_person_one()
- objectstore.flush()
- objectstore.clear()
-
- results = Person.query.all()
-
- self.assertEquals(len(results), 1)
-
- person = results[0]
- self.assertEquals(person.id, p1.id)
- self.assertEquals(len(person.addresses), 2)
- self.assertEquals(person.addresses[0].postal_code, '30338')
-
- @testing.unsupported('mysql')
- def test_update(self):
- p1 = self.create_person_one()
- objectstore.flush()
- objectstore.clear()
-
- person = Person.query.first()
- person.gender = 'F'
- objectstore.flush()
- objectstore.clear()
- self.assertEquals(person.row_version, 2)
-
- person = Person.query.first()
- person.gender = 'M'
- objectstore.flush()
- objectstore.clear()
- self.assertEquals(person.row_version, 3)
-
- #TODO: check that a concurrent modification raises exception
- p1 = Person.query.first()
- s1 = objectstore()
- s2 = create_session()
- objectstore.registry.set(s2)
- p2 = Person.query.first()
- p1.first_name = "jack"
- p2.first_name = "ed"
- objectstore.flush()
- try:
- objectstore.registry.set(s1)
- objectstore.flush()
- # Only dialects with a sane rowcount can detect the ConcurrentModificationError
- if testing.db.dialect.supports_sane_rowcount:
- assert False
- except exceptions.ConcurrentModificationError:
- pass
-
-
- def test_delete(self):
- p1 = self.create_person_one()
-
- objectstore.flush()
- objectstore.clear()
-
- results = Person.query.all()
- self.assertEquals(len(results), 1)
-
- objectstore.delete(results[0])
- objectstore.flush()
- objectstore.clear()
-
- results = Person.query.all()
- self.assertEquals(len(results), 0)
-
-
- def test_multiple(self):
- p1 = self.create_person_one()
- p2 = self.create_person_two()
-
- objectstore.flush()
- objectstore.clear()
-
- # select and make sure we get back two results
- people = Person.query.all()
- self.assertEquals(len(people), 2)
-
- # make sure that our backwards relationships work
- self.assertEquals(people[0].addresses[0].person.id, p1.id)
- self.assertEquals(people[1].addresses[0].person.id, p2.id)
-
- # try a more complex select
- results = Person.query.filter(
- or_(
- and_(
- Address.c.person_id == Person.c.id,
- Address.c.postal_code.like('30075')
- ),
- and_(
- Person.c.prefs_id == Preferences.c.id,
- Preferences.c.favorite_color == 'Green'
- )
- )
- ).all()
- self.assertEquals(len(results), 2)
-
-
- def test_oneway_backref(self):
- # FIXME: I don't know why, but it seems that my backwards relationship
- # on preferences still ends up being a list even though I pass
- # in uselist=False...
- # FIXED: the backref is a new PropertyLoader which needs its own "uselist".
- # uses a function which I dont think existed when you first wrote ActiveMapper.
- p1 = self.create_person_one()
- self.assertEquals(p1.preferences.person, p1)
- objectstore.flush()
- objectstore.delete(p1)
-
- objectstore.flush()
- objectstore.clear()
-
-
- def test_select_by(self):
- # FIXME: either I don't understand select_by, or it doesn't work.
- # FIXED (as good as we can for now): yup....everyone thinks it works that way....it only
- # generates joins for keyword arguments, not ColumnClause args. would need a new layer of
- # "MapperClause" objects to use properties in expressions. (MB)
-
- p1 = self.create_person_one()
- p2 = self.create_person_two()
-
- objectstore.flush()
- objectstore.clear()
-
- results = Person.query.join('addresses').filter(
- Address.c.postal_code.like('30075')
- ).all()
- self.assertEquals(len(results), 1)
-
- self.assertEquals(Person.query.count(), 2)
-
-class testmanytomany(TestBase):
- def setUpAll(self):
- clear_mappers()
- objectstore.clear()
- global secondarytable, foo, baz
- secondarytable = Table("secondarytable",
- activemapper.metadata,
- Column("foo_id", Integer, ForeignKey("foo.id"),primary_key=True),
- Column("baz_id", Integer, ForeignKey("baz.id"),primary_key=True))
-
- class foo(activemapper.ActiveMapper):
- class mapping:
- name = column(String(30))
-# bazrel = many_to_many('baz', secondarytable, backref='foorel')
-
- class baz(activemapper.ActiveMapper):
- class mapping:
- name = column(String(30))
- foorel = many_to_many("foo", secondarytable, backref='bazrel')
-
- activemapper.metadata.bind = testing.db
- activemapper.create_tables()
-
- # Create a couple of activemapper objects
- def create_objects(self):
- return foo(name='foo1'), baz(name='baz1')
-
- def tearDownAll(self):
- clear_mappers()
- activemapper.drop_tables()
- objectstore.clear()
- def testbasic(self):
- # Set up activemapper objects
- foo1, baz1 = self.create_objects()
-
- objectstore.flush()
- objectstore.clear()
-
- foo1 = foo.query.filter_by(name='foo1').one()
- baz1 = baz.query.filter_by(name='baz1').one()
-
- # Just checking ...
- assert (foo1.name == 'foo1')
- assert (baz1.name == 'baz1')
-
- # Diagnostics ...
- # import sys
- # sys.stderr.write("\nbazrel missing from dir(foo1):\n%s\n" % dir(foo1))
- # sys.stderr.write("\nbazrel in foo1 relations:\n%s\n" % foo1.relations)
-
- # Optimistically based on activemapper one_to_many test, try to append
- # baz1 to foo1.bazrel - (AttributeError: 'foo' object has no attribute 'bazrel')
- foo1.bazrel.append(baz1)
- assert (foo1.bazrel == [baz1])
-
-class testselfreferential(TestBase):
- def setUpAll(self):
- clear_mappers()
- objectstore.clear()
- global TreeNode
- class TreeNode(activemapper.ActiveMapper):
- class mapping:
- id = column(Integer, primary_key=True)
- name = column(String(30))
- parent_id = column(Integer, foreign_key=ForeignKey('treenode.id'))
- children = one_to_many('TreeNode', colname='id', backref='parent')
-
- activemapper.metadata.bind = testing.db
- activemapper.create_tables()
- def tearDownAll(self):
- clear_mappers()
- activemapper.drop_tables()
-
- def testbasic(self):
- t = TreeNode(name='node1')
- t.children.append(TreeNode(name='node2'))
- t.children.append(TreeNode(name='node3'))
- objectstore.flush()
- objectstore.clear()
-
- t = TreeNode.query.filter_by(name='node1').one()
- assert (t.name == 'node1')
- assert (t.children[0].name == 'node2')
- assert (t.children[1].name == 'node3')
- assert (t.children[1].parent is t)
-
- objectstore.clear()
- t = TreeNode.query.filter_by(name='node3').one()
- assert (t.parent is TreeNode.query.filter_by(name='node1').one())
-
-if __name__ == '__main__':
- testenv.main()
import doctest, sys, unittest
def suite():
- unittest_modules = ['ext.activemapper',
- 'ext.assignmapper',
+ unittest_modules = [
'ext.declarative',
'ext.orderinglist',
'ext.associationproxy']
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper
-from sqlalchemy.ext.assignmapper import assign_mapper
-from sqlalchemy.ext.sessioncontext import SessionContext
-from testlib import *
-
-
-class AssignMapperTest(TestBase):
- def setUpAll(self):
- global metadata, table, table2
- metadata = MetaData(testing.db)
- table = Table('sometable', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(30)))
- table2 = Table('someothertable', metadata,
- Column('id', Integer, primary_key=True),
- Column('someid', None, ForeignKey('sometable.id'))
- )
- metadata.create_all()
-
- @testing.uses_deprecated('SessionContext', 'assign_mapper')
- def setUp(self):
- global SomeObject, SomeOtherObject, ctx
- class SomeObject(object):pass
- class SomeOtherObject(object):pass
-
- ctx = SessionContext(create_session)
- assign_mapper(ctx, SomeObject, table, properties={
- 'options':relation(SomeOtherObject)
- })
- assign_mapper(ctx, SomeOtherObject, table2)
-
- s = SomeObject()
- s.id = 1
- s.data = 'hello'
- sso = SomeOtherObject()
- s.options.append(sso)
- ctx.current.flush()
- ctx.current.clear()
-
- def tearDownAll(self):
- metadata.drop_all()
-
- def tearDown(self):
- for table in metadata.table_iterator(reverse=True):
- table.delete().execute()
- clear_mappers()
-
- @testing.uses_deprecated('assign_mapper')
- def test_override_attributes(self):
-
- sso = SomeOtherObject.query().first()
-
- assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
-
- s2 = SomeObject(someid=12)
- s3 = SomeOtherObject(someid=123, bogus=345)
-
- class ValidatedOtherObject(object):pass
- assign_mapper(ctx, ValidatedOtherObject, table2, validate=True)
-
- v1 = ValidatedOtherObject(someid=12)
- try:
- v2 = ValidatedOtherObject(someid=12, bogus=345)
- assert False
- except exceptions.ArgumentError:
- pass
-
- @testing.uses_deprecated('assign_mapper')
- def test_dont_clobber_methods(self):
- class MyClass(object):
- def expunge(self):
- return "an expunge !"
-
- assign_mapper(ctx, MyClass, table2)
-
- assert MyClass().expunge() == "an expunge !"
-
-
-if __name__ == '__main__':
- testenv.main()
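# A minimal sketch (not part of the original patch) of the replacement
# idiom for the removed assign_mapper extension, assuming the 0.4/0.5-era
# scoped_session API; the table mirrors the deleted test fixture.
from sqlalchemy import MetaData, Table, Column, Integer, String
from sqlalchemy.orm import scoped_session, sessionmaker

metadata = MetaData()
sometable = Table('sometable', metadata,
    Column('id', Integer, primary_key=True),
    Column('data', String(30)))

Session = scoped_session(sessionmaker())

class SomeObject(object):
    pass

# scoped_session.mapper() attaches the query/save conveniences that
# assign_mapper provided, without the SessionContext dependency.
Session.mapper(SomeObject, sometable)

# usage: the class-bound query and the thread-local Session replace
# ctx.current, e.g.:
#   someobj = SomeObject.query.filter_by(data='hello').first()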
from sqlalchemy.orm.interfaces import MapperExtension
from sqlalchemy.ext.declarative import declarative_base, declared_synonym, \
synonym_for, comparable_using
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from testlib.fixtures import Base as Fixture
from testlib import *
id = Column(Integer, primary_key=True)
foo = column_property(User.id==5)
- self.assertRaises(exceptions.InvalidRequestError, go)
+ self.assertRaises(exc.InvalidRequestError, go)
def test_add_prop(self):
class User(Base, Fixture):
name = Column('name', String(50))
assert False
self.assertRaisesMessage(
- exceptions.ArgumentError,
+ exc.ArgumentError,
"Mapper Mapper|User|users could not assemble any primary key",
define)
def suite():
modules_to_test = (
- 'orm.attributes',
+ 'orm.attributes',
+ 'orm.extendedattr',
+ 'orm.instrumentation',
'orm.query',
'orm.lazy_relations',
'orm.eager_relations',
'orm.assorted_eager',
'orm.naturalpks',
- 'orm.sessioncontext',
'orm.unitofwork',
'orm.session',
+ 'orm.transaction',
+ 'orm.scoping',
'orm.cascade',
'orm.relationships',
'orm.association',
'orm.merge',
'orm.pickled',
'orm.memusage',
+ 'orm.utils',
'orm.cycles',
'orm.manytomany',
'orm.onetoone',
'orm.dynamic',
+
+ 'orm.deprecations',
)
alltests = unittest.TestSuite()
for name in modules_to_test:
from testlib import *
class AssociationTest(TestBase):
- @testing.uses_deprecated('association option')
def setUpAll(self):
global items, item_keywords, keywords, metadata, Item, Keyword, KeywordAssociation
metadata = MetaData(testing.db)
'keyword':relation(Keyword, lazy=False)
}, primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], order_by=[item_keywords.c.data])
mapper(Item, items, properties={
- 'keywords' : relation(KeywordAssociation, association=Keyword)
+ 'keywords' : relation(KeywordAssociation, cascade="all, delete-orphan")
})
def tearDown(self):
print loaded
self.assert_(saved == loaded)
- @testing.uses_deprecated('association option')
def testdelete(self):
sess = create_session()
item1 = Item('item1')
mapper(Originals, table_originals, order_by=Originals.order,
properties={
- 'people': relation(IsAuthor, association=People),
+ 'people': relation(IsAuthor, cascade="all, delete-orphan"),
'authors': relation(People, secondary=table_isauthor, backref='written',
primaryjoin=and_(table_originals.c.ID==table_isauthor.c.OriginalsID,
table_isauthor.c.Kind=='A')),
'date': table_originals.c.Date,
})
mapper(People, table_people, order_by=People.order, properties= {
- 'originals': relation(IsAuthor, association=Originals),
+ 'originals': relation(IsAuthor, cascade="all, delete-orphan"),
'name': table_people.c.Name,
'country': table_people.c.Country,
})
import random, datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
from testlib import *
from testlib import fixtures
print result
assert result == [u'1 Some Category', u'3 Some Category']
- @testing.uses_deprecated('//select')
- def test_withouteagerload_deprecated(self):
- s = create_session()
- l=s.query(Test).select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
- from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
- print result
- assert result == [u'1 Some Category', u'3 Some Category']
-
def test_witheagerload(self):
"""test that an eagerload locates the correct "from" clause with
which to attach to, when presented with a query that already has a complicated from clause."""
print result
assert result == [u'1 Some Category', u'3 Some Category']
- @testing.uses_deprecated('//select')
- def test_witheagerload_deprecated(self):
- """As test_witheagerload, but via select()."""
- s = create_session()
- q=s.query(Test).options(eagerload('category'))
- l=q.select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
- from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
- print result
- assert result == [u'1 Some Category', u'3 Some Category']
-
def test_dslish(self):
"""test the same as witheagerload except using generative"""
s = create_session()
print result
assert result == [u'3 Some Category']
- @testing.unsupported('sybase')
- @testing.uses_deprecated('//select', '//join_to')
- def test_withoutouterjoin_literal_deprecated(self):
- s = create_session()
- q=s.query(Test).options(eagerload('category'))
- l=q.select( (tests.c.owner_id==1) & ('options.someoption is null or options.someoption=%s' % false) & q.join_to('owner_option') )
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
- print result
- assert result == [u'3 Some Category']
-
def test_withoutouterjoin(self):
s = create_session()
q=s.query(Test).options(eagerload('category'))
print result
assert result == [u'3 Some Category']
- @testing.uses_deprecated('//select', '//join_to', '//join_via')
- def test_withoutouterjoin_deprecated(self):
- s = create_session()
- q=s.query(Test).options(eagerload('category'))
- l=q.select( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) & q.join_to('owner_option') )
- result = ["%d %s" % ( t.id,t.category.name ) for t in l]
- print result
- assert result == [u'3 Some Category']
-
class EagerTest2(TestBase, AssertsExecutionResults):
def setUpAll(self):
global metadata, middle, left, right
sess.flush()
q = sess.query(Department)
- q = q.join('employees').filter(Employee.c.name.startswith('J')).distinct().order_by([desc(Department.c.name)])
+ q = q.join('employees').filter(Employee.name.startswith('J')).distinct().order_by([desc(Department.name)])
assert q.count() == 2
assert q[0] is d2
x.inheritedParts
class EagerTest7(ORMTest):
- @testing.uses_deprecated('SessionContext')
def define_tables(self, metadata):
global companies_table, addresses_table, invoice_table, phones_table, items_table, ctx
global Company, Address, Phone, Item,Invoice
- ctx = SessionContext(create_session)
+ ctx = scoped_session(create_session)
companies_table = Table('companies', metadata,
Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True),
def __repr__(self):
return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty)
- @testing.uses_deprecated('SessionContext')
def testone(self):
"""tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated
the bug, which is that when the single Company is loaded, no further processing of the rows
occurred in order to load the Company's second Address object."""
mapper(Address, addresses_table, properties={
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
mapper(Company, companies_table, properties={
'addresses' : relation(Address, lazy=False),
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
mapper(Invoice, invoice_table, properties={
'company': relation(Company, lazy=False, )
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
c1 = Company()
c1.company_name = 'company 1'
i1.date = datetime.datetime.now()
i1.company = c1
- ctx.current.flush()
+ ctx.flush()
company_id = c1.company_id
invoice_id = i1.invoice_id
- ctx.current.clear()
+ ctx.clear()
- c = ctx.current.query(Company).get(company_id)
+ c = ctx.query(Company).get(company_id)
- ctx.current.clear()
+ ctx.clear()
- i = ctx.current.query(Invoice).get(invoice_id)
+ i = ctx.query(Invoice).get(invoice_id)
print repr(c)
print repr(i.company)
def testtwo(self):
"""this is the original testcase that includes various complicating factors"""
- mapper(Phone, phones_table, extension=ctx.mapper_extension)
+ mapper(Phone, phones_table, extension=ctx.extension)
mapper(Address, addresses_table, properties={
'phones': relation(Phone, lazy=False, backref='address')
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
mapper(Company, companies_table, properties={
'addresses' : relation(Address, lazy=False, backref='company'),
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
- mapper(Item, items_table, extension=ctx.mapper_extension)
+ mapper(Item, items_table, extension=ctx.extension)
mapper(Invoice, invoice_table, properties={
'items': relation(Item, lazy=False, backref='invoice'),
'company': relation(Company, lazy=False, backref='invoices')
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
- ctx.current.clear()
+ ctx.clear()
c1 = Company()
c1.company_name = 'company 1'
c1.addresses.append(a2)
- ctx.current.flush()
+ ctx.flush()
company_id = c1.company_id
- ctx.current.clear()
+ ctx.clear()
- a = ctx.current.query(Company).get(company_id)
+ a = ctx.query(Company).get(company_id)
print repr(a)
# set up an invoice
item3.qty = 3
item3.invoice = i1
- ctx.current.flush()
+ ctx.flush()
invoice_id = i1.invoice_id
- ctx.current.clear()
+ ctx.clear()
- c = ctx.current.query(Company).get(company_id)
+ c = ctx.query(Company).get(company_id)
print repr(c)
- ctx.current.clear()
+ ctx.clear()
- i = ctx.current.query(Invoice).get(invoice_id)
+ i = ctx.query(Invoice).get(invoice_id)
assert repr(i.company) == repr(c), repr(i.company) + " does not match " + repr(c)
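# A minimal sketch (not part of the original patch) of the
# SessionContext -> scoped_session migration performed in the hunks
# above, assuming the 0.4/0.5-era scoped_session API.
from sqlalchemy.orm import scoped_session, create_session

ctx = scoped_session(create_session)

# old: ctx = SessionContext(create_session); ctx.current.flush()
# new: the registry proxies session methods to the thread-local session
ctx.flush()                                  # was ctx.current.flush()
# mapper(SomeClass, sometable, extension=ctx.extension)
#                                            # was ctx.mapper_extension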
import pickle
import sqlalchemy.orm.attributes as attributes
from sqlalchemy.orm.collections import collection
-from sqlalchemy import exceptions
+from sqlalchemy.orm.interfaces import AttributeExtension
+from sqlalchemy import exc as sa_exc
from testlib import *
from testlib import fixtures
-ROLLBACK_SUPPORTED=False
-
-# these test classes defined at the module
-# level to support pickling
-class MyTest(object):pass
-class MyTest2(object):pass
+# module-level names so pickle can locate the test classes; setUp rebinds them
+MyTest = None
+MyTest2 = None
class AttributesTest(TestBase):
+ def setUp(self):
+ global MyTest, MyTest2
+ class MyTest(object): pass
+ class MyTest2(object): pass
+
+ def tearDown(self):
+ global MyTest, MyTest2
+ MyTest, MyTest2 = None, None
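+
+ # Note on the mechanical change throughout the tests below: the tests
+ # no longer reach for `obj._state` directly; the InstanceState is
+ # obtained through the attributes module (a sketch of the API as
+ # exercised by the assertions that follow):
+ #
+ # state = attributes.instance_state(obj) # was: obj._state
+ # state.commit_all()
+ # attributes.get_history(state, 'someattr')
+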
def test_basic(self):
class User(object):pass
u.email_address = 'lala@123.com'
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
- u._state.commit_all()
+ attributes.instance_state(u).commit_all()
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
u.user_name = 'heythere'
class Foo(object):pass
data = {'a':'this is a', 'b':12}
- def loader(instance, keys):
+ def loader(state, keys):
for k in keys:
- instance.__dict__[k] = data[k]
+ state.dict[k] = data[k]
return attributes.ATTR_WAS_SET
- attributes.register_class(Foo, deferred_scalar_loader=loader)
+ attributes.register_class(Foo)
+ manager = attributes.manager_of_class(Foo)
+ manager.deferred_scalar_loader = loader
attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
f = Foo()
- f._state.expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(None)
self.assertEquals(f.a, "this is a")
self.assertEquals(f.b, 12)
f.a = "this is some new a"
- f._state.expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(None)
self.assertEquals(f.a, "this is a")
self.assertEquals(f.b, 12)
- f._state.expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(None)
f.a = "this is another new a"
self.assertEquals(f.a, "this is another new a")
self.assertEquals(f.b, 12)
- f._state.expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(None)
self.assertEquals(f.a, "this is a")
self.assertEquals(f.b, 12)
self.assertEquals(f.a, None)
self.assertEquals(f.b, 12)
- f._state.commit_all()
+ attributes.instance_state(f).commit_all()
self.assertEquals(f.a, None)
self.assertEquals(f.b, 12)
def test_deferred_pickleable(self):
data = {'a':'this is a', 'b':12}
- def loader(instance, keys):
+ def loader(state, keys):
for k in keys:
- instance.__dict__[k] = data[k]
+ state.dict[k] = data[k]
return attributes.ATTR_WAS_SET
- attributes.register_class(MyTest, deferred_scalar_loader=loader)
+ attributes.register_class(MyTest)
+ manager = attributes.manager_of_class(MyTest)
+ manager.deferred_scalar_loader=loader
attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False)
attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
m = MyTest()
- m._state.expire_attributes(None)
+ attributes.instance_state(m).expire_attributes(None)
assert 'a' not in m.__dict__
m2 = pickle.loads(pickle.dumps(m))
assert 'a' not in m2.__dict__
u.addresses.append(a)
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
- u, a._state.commit_all()
+ attributes.instance_state(u).commit_all()
+ attributes.instance_state(a).commit_all()
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
u.user_name = 'heythere'
u.addresses.append(a)
self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
+ def test_scalar_listener(self):
+ # listeners on ScalarAttributeImpl and MutableScalarAttributeImpl aren't used normally.
+ # test that they work for the benefit of user extensions
+ class Foo(object):
+ pass
+
+ results = []
+ class ReceiveEvents(AttributeExtension):
+ def append(self, state, child, initiator):
+ assert False
+
+ def remove(self, state, child, initiator):
+ results.append(("remove", state.obj(), child))
+
+ def set(self, state, child, oldchild, initiator):
+ results.append(("set", state.obj(), child, oldchild))
+
+ attributes.register_class(Foo)
+ attributes.register_attribute(Foo, 'x', uselist=False, mutable_scalars=False, useobject=False, extension=ReceiveEvents())
+ attributes.register_attribute(Foo, 'y', uselist=False, mutable_scalars=True, useobject=False, copy_function=lambda x:x, extension=ReceiveEvents())
+
+ f = Foo()
+ f.x = 5
+ f.x = 17
+ del f.x
+ f.y = [1,2,3]
+ f.y = [4,5,6]
+ del f.y
+
+ self.assertEquals(results, [
+ ('set', f, 5, None),
+ ('set', f, 17, 5),
+ ('remove', f, 17),
+ ('set', f, [1,2,3], None),
+ ('set', f, [4,5,6], [1,2,3]),
+ ('remove', f, [4,5,6])
+ ])
+
+
def test_lazytrackparent(self):
"""test that the "hasparent" flag works properly when lazy loaders and backrefs are used"""
# create objects as if they'd been freshly loaded from the database (without history)
b = Blog()
p1 = Post()
- b._state.set_callable('posts', lambda:[p1])
- p1._state.set_callable('blog', lambda:b)
- p1, b._state.commit_all()
+ attributes.instance_state(b).set_callable('posts', lambda:[p1])
+ attributes.instance_state(p1).set_callable('blog', lambda:b)
+ attributes.instance_state(p1).commit_all()
+ attributes.instance_state(b).commit_all()
# no orphans (called before the lazy loaders fire off)
assert attributes.has_parent(Blog, p1, 'posts', optimistic=True)
states = set()
class Foo(object):
def __init__(self):
- states.add(self._state)
+ states.add(attributes.instance_state(self))
class Bar(Foo):
def __init__(self):
- states.add(self._state)
+ states.add(attributes.instance_state(self))
Foo.__init__(self)
el = Element()
x = Bar()
x.element = el
- self.assertEquals(attributes.get_history(x._state, 'element'), ([el],[], []))
- x._state.commit_all()
+ self.assertEquals(attributes.get_history(attributes.instance_state(x), 'element'), ([el],[], []))
+ attributes.instance_state(x).commit_all()
- (added, unchanged, deleted) = attributes.get_history(x._state, 'element')
+ (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element')
assert added == []
assert unchanged == [el]
attributes.register_attribute(Bar, 'id', uselist=False, useobject=True)
x = Foo()
- x._state.commit_all()
+ attributes.instance_state(x).commit_all()
x.col2.append(bar4)
- self.assertEquals(attributes.get_history(x._state, 'col2'), ([bar4], [bar1, bar2, bar3], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], []))
def test_parenttrack(self):
class Foo(object):pass
attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False)
x = Foo()
x.element = ['one', 'two', 'three']
- x._state.commit_all()
+ attributes.instance_state(x).commit_all()
x.element[1] = 'five'
- assert x._state.is_modified()
+ assert attributes.instance_state(x).check_modified()
attributes.unregister_class(Foo)
attributes.register_attribute(Foo, 'element', uselist=False, useobject=False)
x = Foo()
x.element = ['one', 'two', 'three']
- x._state.commit_all()
+ attributes.instance_state(x).commit_all()
x.element[1] = 'five'
- assert not x._state.is_modified()
+ assert not attributes.instance_state(x).check_modified()
def test_descriptorattributes(self):
"""changeset: 1633 broke ability to use ORM to map classes with unusual
This is a simple regression test to prevent that defect.
"""
class des(object):
- def __get__(self, instance, owner): raise AttributeError('fake attribute')
+ def __get__(self, instance, owner):
+ raise AttributeError('fake attribute')
class Foo(object):
A = des()
-
+ attributes.register_class(Foo)
attributes.unregister_class(Foo)
def test_collectionclasses(self):
class Foo(object):pass
attributes.register_class(Foo)
+
attributes.register_attribute(Foo, "collection", uselist=True, typecallable=set, useobject=True)
+ assert attributes.manager_of_class(Foo).is_instrumented("collection")
assert isinstance(Foo().collection, set)
attributes.unregister_attribute(Foo, "collection")
-
+ assert not attributes.manager_of_class(Foo).is_instrumented("collection")
+
try:
attributes.register_attribute(Foo, "collection", uselist=True, typecallable=dict, useobject=True)
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Type InstrumentedDict must elect an appender method to be a collection class"
class MyDict(dict):
try:
attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True)
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Type MyColl must elect an appender method to be a collection class"
class MyColl(object):
try:
Foo().collection
assert True
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert False
j.port = None
self.assert_(p.jack is None)
-class DeferredBackrefTest(TestBase):
+class PendingBackrefTest(TestBase):
def setUp(self):
global Post, Blog, called, lazy_load
b = Blog("blog 1")
p = Post("post 4")
+
p.blog = b
p = Post("post 5")
p.blog = b
# calling backref calls the callable, populates extra posts
assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")]
assert called[0] == 1
+
+ def test_lazy_history(self):
+ global lazy_load
+
+ p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3")
+ lazy_load = [p1, p2, p3]
+
+ b = Blog("blog 1")
+ p = Post("post 4")
+ p.blog = b
+
+ p4 = Post("post 5")
+ p4.blog = b
+ assert called[0] == 0
+ self.assertEquals(attributes.instance_state(b).get_history('posts'), ([p, p4], [p1, p2, p3], []))
+ assert called[0] == 1
def test_lazy_remove(self):
global lazy_load
attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False)
f = Foo()
- self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+ self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
f.someattr = 3
- self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+ self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
f = Foo()
f.someattr = 3
- self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+ self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
- f._state.commit(['someattr'])
- self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), 3)
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), 3)
def test_scalar(self):
class Foo(fixtures.Base):
# case 1. new object
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr = "hi"
- self.assertEquals(attributes.get_history(f._state, 'someattr'), (['hi'], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['hi'], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['hi'], []))
f.someattr = 'there'
- self.assertEquals(attributes.get_history(f._state, 'someattr'), (['there'], [], ['hi']))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], [], ['hi']))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['there'], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['there'], []))
del f.someattr
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], ['there']))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], ['there']))
# case 2. object with direct dictionary settings (similar to a load operation)
f = Foo()
f.__dict__['someattr'] = 'new'
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
f.someattr = 'old'
- self.assertEquals(attributes.get_history(f._state, 'someattr'), (['old'], [], ['new']))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], [], ['new']))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['old'], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['old'], []))
# setting None on uninitialized is currently a change for a scalar attribute
# no lazyload occurs so this allows overwrite operation to proceed
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr = None
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], []))
f = Foo()
f.__dict__['someattr'] = 'new'
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
f.someattr = None
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], ['new']))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], ['new']))
+ # set same value twice
+ f = Foo()
+ attributes.instance_state(f).commit(['someattr'])
+ f.someattr = 'one'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], [], []))
+ f.someattr = 'two'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], [], []))
+
+
def test_mutable_scalar(self):
class Foo(fixtures.Base):
pass
# case 1. new object
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr = {'foo':'hi'}
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'hi'}], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'hi'}], []))
- self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'})
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'hi'}], []))
+ self.assertEquals(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
f.someattr['foo'] = 'there'
- self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'})
+ self.assertEquals(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'there'}], [], [{'foo':'hi'}]))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'there'}], [], [{'foo':'hi'}]))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'there'}], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'there'}], []))
# case 2. object with direct dictionary settings (similar to a load operation)
f = Foo()
f.__dict__['someattr'] = {'foo':'new'}
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'new'}], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'new'}], []))
f.someattr = {'foo':'old'}
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'old'}], [], [{'foo':'new'}]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], [], [{'foo':'new'}]))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'old'}], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'old'}], []))
def test_use_object(self):
# case 1. new object
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
f.someattr = hi
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
f.someattr = there
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi]))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi]))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [there], []))
del f.someattr
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [there]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], [there]))
# case 2. object with direct dictionary settings (similar to a load operation)
f = Foo()
- f.__dict__['someattr'] = new
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ f.__dict__['someattr'] = 'new'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
f.someattr = old
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], ['new']))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old], []))
# setting None on uninitialized is currently not a change for an object attribute
# (this is different from a scalar attribute). a lazyload has occurred, so if it's
# None, it's really None
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
f.someattr = None
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
f = Foo()
- f.__dict__['someattr'] = new
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ f.__dict__['someattr'] = 'new'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
f.someattr = None
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [new]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], ['new']))
+
+ # set same value twice
+ f = Foo()
+ attributes.instance_state(f).commit(['someattr'])
+ f.someattr = 'one'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], [], []))
+ f.someattr = 'two'
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], [], []))
def test_object_collections_set(self):
class Foo(fixtures.Base):
# case 1. new object
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr = [hi]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
f.someattr = [there]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi]))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi]))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [there], []))
f.someattr = [hi]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [there]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [there]))
f.someattr = [old, new]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [], [there]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [], [there]))
# case 2. object with direct settings (similar to a load operation)
f = Foo()
- collection = attributes.init_collection(f, 'someattr')
+ collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
collection.append_without_event(new)
- f._state.commit_all()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ attributes.instance_state(f).commit_all()
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
f.someattr = [old]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new]))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old], []))
def test_dict_collections(self):
class Foo(fixtures.Base):
new = Bar(name='new')
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr['hi'] = hi
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
f.someattr['there'] = there
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([hi, there]), set([]), set([])))
+ self.assertEquals(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set([]), set([])))
- f._state.commit(['someattr'])
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([]), set([hi, there]), set([])))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([]), set([hi, there]), set([])))
def test_object_collections_mutate(self):
class Foo(fixtures.Base):
# case 1. new object
f = Foo(id=1)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
f.someattr.append(hi)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
f.someattr.append(there)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [hi], []))
- f._state.commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], []))
+ attributes.instance_state(f).commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, there], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi, there], []))
f.someattr.remove(there)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [there]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], [there]))
f.someattr.append(old)
f.someattr.append(new)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [hi], [there]))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, old, new], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there]))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi, old, new], []))
f.someattr.pop(0)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old, new], [hi]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old, new], [hi]))
# case 2. object with direct settings (similar to a load operation)
f = Foo()
f.__dict__['id'] = 1
- collection = attributes.init_collection(f, 'someattr')
+ collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
collection.append_without_event(new)
- f._state.commit_all()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ attributes.instance_state(f).commit_all()
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
f.someattr.append(old)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [new], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], []))
- f._state.commit(['someattr'])
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new, old], []))
+ attributes.instance_state(f).commit(['someattr'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new, old], []))
f = Foo()
- collection = attributes.init_collection(f, 'someattr')
+ collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
collection.append_without_event(new)
- f._state.commit_all()
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+ attributes.instance_state(f).commit_all()
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
f.id = 1
f.someattr.remove(new)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [new]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [new]))
# case 3. mixing appends with sets
f = Foo()
f.someattr.append(hi)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
f.someattr.append(there)
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi, there], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there], [], []))
f.someattr = [there]
- self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], []))
def test_collections_via_backref(self):
class Foo(fixtures.Base):
f1 = Foo()
b1 = Bar()
- self.assertEquals(attributes.get_history(f1._state, 'bars'), ([], [], []))
- self.assertEquals(attributes.get_history(b1._state, 'foo'), ([], [None], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([], [None], []))
#b1.foo = f1
f1.bars.append(b1)
- self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1], [], []))
- self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], [], []))
b2 = Bar()
f1.bars.append(b2)
- self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1, b2], [], []))
- self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], []))
- self.assertEquals(attributes.get_history(b2._state, 'foo'), ([f1], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1, b2], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(b2), 'foo'), ([f1], [], []))
def test_lazy_backref_collections(self):
class Foo(fixtures.Base):
f = Foo()
bar4 = Bar()
bar4.foo = f
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], []))
lazy_load = None
f = Foo()
bar4 = Bar()
bar4.foo = f
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [], []))
lazy_load = [bar1, bar2, bar3]
- f._state.expire_attributes(['bars'])
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], []))
+ attributes.instance_state(f).expire_attributes(['bars'])
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar2, bar3], []))
def test_collections_via_lazyload(self):
class Foo(fixtures.Base):
f = Foo()
f.bars = []
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [], [bar1, bar2, bar3]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [], [bar1, bar2, bar3]))
f = Foo()
f.bars.append(bar4)
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []) )
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], []) )
f = Foo()
f.bars.remove(bar2)
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2]))
f.bars.append(bar4)
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar3], [bar2]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar3], [bar2]))
f = Foo()
del f.bars[1]
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2]))
lazy_load = None
f = Foo()
f.bars.append(bar2)
- self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar2], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar2], [], []))
def test_scalar_via_lazyload(self):
class Foo(fixtures.Base):
f = Foo()
self.assertEquals(f.bar, "hi")
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([], ["hi"], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], ["hi"], []))
f = Foo()
f.bar = None
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], []))
f = Foo()
f.bar = "there"
- self.assertEquals(attributes.get_history(f._state, 'bar'), (["there"], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), (["there"], [], []))
f.bar = "hi"
- self.assertEquals(attributes.get_history(f._state, 'bar'), (["hi"], [], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), (["hi"], [], []))
f = Foo()
self.assertEquals(f.bar, "hi")
del f.bar
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [], ["hi"]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [], ["hi"]))
assert f.bar is None
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], ["hi"]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], ["hi"]))
def test_scalar_object_via_lazyload(self):
class Foo(fixtures.Base):
# operations
f = Foo()
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [bar1], []))
f = Foo()
f.bar = None
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
f = Foo()
f.bar = bar2
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([bar2], [], [bar1]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([bar2], [], [bar1]))
f.bar = bar1
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [bar1], []))
f = Foo()
self.assertEquals(f.bar, bar1)
del f.bar
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
assert f.bar is None
- self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
+
if __name__ == "__main__":
testenv.main()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
from testlib import *
from testlib import fixtures
try:
sess.flush()
assert False
- except exceptions.FlushError, e:
+ except orm_exc.FlushError, e:
assert "is an orphan" in str(e)
def test_delete(self):
s.save(a)
try:
s.flush()
- except exceptions.FlushError, e:
+ except orm_exc.FlushError, e:
pass
assert a.address_id is None, "Error: address should not be persistent"
try:
session.flush()
assert False
- except exceptions.FlushError, e:
+ except orm_exc.FlushError, e:
assert True
class CollectionAssignmentOrphanTest(ORMTest):
self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]))
a1 = sess.query(A).get(a1.id)
- assert not class_mapper(B)._is_orphan(a1.bs[0])
+ assert not class_mapper(B)._is_orphan(attributes.instance_state(a1.bs[0]))
a1.bs[0].foo='b2modified'
a1.bs[1].foo='b3modified'
sess.flush()
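# The exception import split used throughout these tests after the
# sqlalchemy.exceptions -> sqlalchemy.exc rename (only the imports that
# actually appear in this patch):
#
#   from sqlalchemy import exc as sa_exc        # core, e.g. ArgumentError
#   from sqlalchemy.orm import exc as orm_exc   # ORM, e.g. FlushError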
import sys
from operator import and_
from sqlalchemy import *
-import sqlalchemy.exceptions as exceptions
+import sqlalchemy.exc as sa_exc
from sqlalchemy.orm import create_session, mapper, relation, \
interfaces, attributes
import sqlalchemy.orm.collections as collections
self._test_adapter(dict, dictable_entity,
to_set=lambda c: set(c.values()))
self.assert_(False)
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
try:
self._test_dict(dict)
self.assert_(False)
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
def test_dict_subclass(self):
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from testlib import *
try:
class_mapper(Product).compile()
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e).index("Error creating backref ") > -1
def testthree(self):
try:
compile_mappers()
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e).index("Error creating backref") > -1
if __name__ == '__main__':
Column("child2_data", String(50))
)
meta.create_all()
+
def tearDownAll(self):
meta.drop_all()
+
def testmanytooneonly(self):
"""test similar to SelfReferentialTest.testmanytooneonly"""
+
class Parent(object):
- pass
+ pass
mapper(Parent, parent)
class Child1(Parent):
- pass
+ pass
mapper(Child1, child1, inherits=Parent)
class Child2(Parent):
- pass
+ pass
mapper(Child2, child2, properties={
"child1": relation(Child1,
class InheritTestTwo(ORMTest):
"""the fix in BiDirectionalManyToOneTest raised this issue, regarding
the 'circular sort' containing UOWTasks that were still polymorphic, which could
- create duplicate entries in the final sort"""
+ create duplicate entries in the final sort
+
+ """
def define_tables(self, metadata):
global a, b, c
a = Table('a', metadata,
Column('data', String(30)),
Column('aid', Integer, ForeignKey('a.id', use_alter=True, name="foo")),
)
+
def test_flush(self):
class A(object):pass
class B(A):pass
def testcycle(self):
"""this test has a peculiar aspect in that it doesnt create as many dependent
- relationships as the other tests, and revealed a small glitch in the circular dependency sorting."""
+ relationships as the other tests, and revealed a small glitch in the circular dependency sorting.
+
+ """
class Person(object):
- pass
+ pass
class Ball(object):
- pass
+ pass
Ball.mapper = mapper(Ball, ball)
Person.mapper = mapper(Person, person, properties= dict(
- balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id),
- favorateBall = relation(Ball.mapper, primaryjoin=person.c.favorite_ball_id==ball.c.id, remote_side=person.c.favorite_ball_id),
+ balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id),
+ favorateBall = relation(Ball.mapper, primaryjoin=person.c.favorite_ball_id==ball.c.id, remote_side=ball.c.id),
)
)
p = Person()
p.balls.append(b)
sess = create_session()
- sess.save(b)
- sess.save(b)
+ sess.save(p)
sess.flush()
-
+
def testpostupdate_m2o(self):
"""tests a cycle between two rows, with a post_update on the many-to-one"""
class Person(object):
a_table.create()
def tearDownAll(self):
a_table.drop()
+
def testbasic(self):
"""test that post_update remembers to be involved in update operations as well,
since it replaces the normal dependency processing completely [ticket:413]"""
--- /dev/null
+"""The collection of modern alternatives to deprecated & removed functionality.
+
+Collects specimens of old ORM code and explicitly covers the recommended
+modern (i.e. not deprecated) alternatives to them. The test snippets here can
+be migrated directly to the wiki, docs, etc.
+
+"""
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
+
+users_table, addresses_table = None, None
+session = None
+
+class Base(object):
+ def __init__(self, **kw):
+ for k, v in kw.iteritems():
+ setattr(self, k, v)
+
+class User(Base): pass
+class Address(Base): pass
+
+
+class QueryAlternativesTest(ORMTest):
+ '''Collects modern idioms for Queries
+
+ The docstring for each test case serves as miniature documentation about
+ the deprecated use case, and the test body illustrates (and covers) the
+ intended replacement code to accomplish the same task.
+
+ Documenting the "old way" including the argument signature helps these
+ cases remain useful to readers even after the deprecated method has been
+ removed from the modern codebase.
+
+ Format:
+
+ def test_deprecated_thing(self):
+ """Query.methodname(old, arg, **signature)
+
+ output = session.query(User).deprecatedmethod(inputs)
+
+ """
+ # 0.4+
+ output = session.query(User).newway(inputs)
+ assert output is correct
+
+ # 0.5+
+ output = session.query(User).evennewerway(inputs)
+ assert output is correct
+
+ '''
+ keep_mappers = True
+ keep_data = True
+
+ def define_tables(self, metadata):
+ global users_table, addresses_table
+ users_table = Table(
+ 'users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(64)))
+
+ addresses_table = Table(
+ 'addresses', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('user_id', Integer, ForeignKey('users.id')),
+ Column('email_address', String(128)),
+ Column('purpose', String(16)),
+ Column('bounces', Integer, default=0))
+
+ def setup_mappers(self):
+ mapper(User, users_table, properties=dict(
+ addresses=relation(Address, backref='user'),
+ ))
+ mapper(Address, addresses_table)
+
+ def insert_data(self):
+ user_cols = ('id', 'name')
+ user_rows = ((1, 'jack'), (2, 'ed'), (3, 'fred'), (4, 'chuck'))
+ users_table.insert().execute(
+ [dict(zip(user_cols, row)) for row in user_rows])
+
+ add_cols = ('id', 'user_id', 'email_address', 'purpose', 'bounces')
+ add_rows = (
+ (1, 1, 'jack@jack.home', 'Personal', 0),
+ (2, 1, 'jack@jack.bizz', 'Work', 1),
+ (3, 2, 'ed@foo.bar', 'Personal', 0),
+ (4, 3, 'fred@the.fred', 'Personal', 10))
+
+ addresses_table.insert().execute(
+ [dict(zip(add_cols, row)) for row in add_rows])
+
+ def setUp(self):
+ super(QueryAlternativesTest, self).setUp()
+ global session
+ if session is None:
+ session = create_session()
+
+ def tearDown(self):
+ super(QueryAlternativesTest, self).tearDown()
+ session.clear()
+
+ ######################################################################
+
+ def test_apply_max(self):
+ """Query.apply_max(col)
+
+ max = session.query(Address).apply_max(Address.bounces)
+
+ """
+ # 0.5.0
+ maxes = list(session.query(Address).values(func.max(Address.bounces)))
+ max = maxes[0][0]
+ assert max == 10
+
+ max = session.query(func.max(Address.bounces)).one()[0]
+ assert max == 10
+
+ def test_apply_min(self):
+ """Query.apply_min(col)
+
+ min = session.query(Address).apply_min(Address.bounces)
+
+ """
+ # 0.5.0
+ mins = list(session.query(Address).values(func.min(Address.bounces)))
+ min = mins[0][0]
+ assert min == 0
+
+ min = session.query(func.min(Address.bounces)).one()[0]
+ assert min == 0
+
+ def test_apply_avg(self):
+ """Query.apply_avg(col)
+
+ avg = session.query(Address).apply_avg(Address.bounces)
+
+ """
+ avgs = list(session.query(Address).values(func.avg(Address.bounces)))
+ avg = avgs[0][0]
+ assert avg > 0 and avg < 10
+
+ avg = session.query(func.avg(Address.bounces)).one()[0]
+ assert avg > 0 and avg < 10
+
+ def test_apply_sum(self):
+ """Query.apply_sum(col)
+
+ sum = session.query(Address).apply_sum(Address.bounces)
+
+ """
+ sums = list(session.query(Address).values(func.sum(Address.bounces)))
+ sum = sums[0][0]
+ assert sum == 11, sum
+
+ sum = session.query(func.sum(Address.bounces)).one()[0]
+ assert sum == 11, sum
+
+ def test_count_by(self):
+ """Query.count_by(*args, **params)
+
+ num = session.query(Address).count_by(purpose='Personal')
+
+ # old-style implicit *_by join
+ num = session.query(User).count_by(purpose='Personal')
+
+ """
+ num = session.query(Address).filter_by(purpose='Personal').count()
+ assert num == 3, num
+
+ num = (session.query(User).join('addresses').
+ filter(Address.purpose=='Personal')).count()
+ assert num == 3, num
+
+ def test_count_whereclause(self):
+ """Query.count(whereclause=None, params=None, **kwargs)
+
+ num = session.query(Address).count(addresses_table.c.bounces > 1)
+
+ """
+ num = session.query(Address).filter(Address.bounces > 1).count()
+ assert num == 1, num
+
+ def test_execute(self):
+ """Query.execute(clauseelement, params=None, *args, **kwargs)
+
+ users = session.query(User).execute(users_table.select())
+
+ """
+ users = session.query(User).from_statement(users_table.select()).all()
+ assert len(users) == 4
+
+ def test_get_by(self):
+ """Query.get_by(*args, **params)
+
+ user = session.query(User).get_by(name='ed')
+
+ # 0.3-style implicit *_by join
+ user = session.query(User).get_by(email_address='fred@the.fred')
+
+ """
+ user = session.query(User).filter_by(name='ed').first()
+ assert user.name == 'ed'
+
+ user = (session.query(User).join('addresses').
+ filter(Address.email_address=='fred@the.fred')).first()
+ assert user.name == 'fred'
+
+ user = session.query(User).filter(
+ User.addresses.any(Address.email_address=='fred@the.fred')).first()
+ assert user.name == 'fred'
+
+ def test_instances_entities(self):
+ """Query.instances(cursor, *mappers_or_columns, **kwargs)
+
+ sel = users_table.join(addresses_table).select(use_labels=True)
+ res = session.query(User).instances(sel.execute(), Address)
+
+ """
+ sel = users_table.join(addresses_table).select(use_labels=True)
+ res = session.query(User, Address).instances(sel.execute())
+
+ assert len(res) == 4
+ cola, colb = res[0]
+ assert isinstance(cola, User) and isinstance(colb, Address)
+
+
+ def test_join_by(self):
+ """Query.join_by(*args, **params)
+
+ TODO
+ """
+
+ def test_join_to(self):
+ """Query.join_to(key)
+
+ TODO
+ """
+
+ def test_join_via(self):
+ """Query.join_via(keys)
+
+ TODO
+ """
+
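+ # A sketch of the likely replacements for the three TODO stubs above,
+ # following the pattern of the other *_by tests in this file: the
+ # implicit join_by/join_to/join_via criteria become an explicit
+ # join() plus filter()/filter_by(), e.g.
+ #
+ # session.query(User).join('addresses').filter_by(purpose='Personal')
+ #
+ # Left as commentary since the original stubs are intentionally TODO.
+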
+ def test_list(self):
+ """Query.list()
+
+ users = session.query(User).list()
+
+ """
+ users = session.query(User).all()
+ assert len(users) == 4
+
+ def test_scalar(self):
+ """Query.scalar()
+
+ user = session.query(User).filter(User.id==1).scalar()
+
+ """
+ user = session.query(User).filter(User.id==1).first()
+ assert user.id==1
+
+ def test_select(self):
+ """Query.select(arg=None, **kwargs)
+
+ users = session.query(User).select(users_table.c.name != None)
+
+ """
+ users = session.query(User).filter(User.name != None).all()
+ assert len(users) == 4
+
+ def test_select_by(self):
+ """Query.select_by(*args, **params)
+
+ users = session.query(User).select_by(name='fred')
+
+ # 0.3 magic join on *_by methods
+ users = session.query(User).select_by(email_address='fred@the.fred')
+
+ """
+ users = session.query(User).filter_by(name='fred').all()
+ assert len(users) == 1
+
+ users = session.query(User).filter(User.name=='fred').all()
+ assert len(users) == 1
+
+ users = (session.query(User).join('addresses').
+ filter_by(email_address='fred@the.fred')).all()
+ assert len(users) == 1
+
+ users = session.query(User).filter(User.addresses.any(
+ Address.email_address == 'fred@the.fred')).all()
+ assert len(users) == 1
+
+ def test_selectfirst(self):
+ """Query.selectfirst(arg=None, **kwargs)
+
+ bounced = session.query(Address).selectfirst(
+ addresses_table.c.bounces > 0)
+
+ """
+ bounced = session.query(Address).filter(Address.bounces > 0).first()
+ assert bounced.bounces > 0
+
+ def test_selectfirst_by(self):
+ """Query.selectfirst_by(*args, **params)
+
+ onebounce = session.query(Address).selectfirst_by(bounces=1)
+
+ # 0.3 magic join on *_by methods
+ onebounce_user = session.query(User).selectfirst_by(bounces=1)
+
+ """
+ onebounce = session.query(Address).filter_by(bounces=1).first()
+ assert onebounce.bounces == 1
+
+ onebounce_user = (session.query(User).join('addresses').
+ filter_by(bounces=1)).first()
+ assert onebounce_user.name == 'jack'
+
+ onebounce_user = (session.query(User).join('addresses').
+ filter(Address.bounces == 1)).first()
+ assert onebounce_user.name == 'jack'
+
+ onebounce_user = session.query(User).filter(User.addresses.any(
+ Address.bounces == 1)).first()
+ assert onebounce_user.name == 'jack'
+
+ def test_selectone(self):
+ """Query.selectone(arg=None, **kwargs)
+
+ ed = session.query(User).selectone(users_table.c.name == 'ed')
+
+ """
+        jack = session.query(User).filter(User.name == 'jack').one()
+
+ def test_selectone_by(self):
+ """Query.selectone_by
+
+ ed = session.query(User).selectone_by(name='ed')
+
+ # 0.3 magic join on *_by methods
+ ed = session.query(User).selectone_by(email_address='ed@foo.bar')
+
+ """
+        jack = session.query(User).filter_by(name='jack').one()
+
+        jack = session.query(User).filter(User.name == 'jack').one()
+
+ ed = session.query(User).join('addresses').filter(
+ Address.email_address == 'ed@foo.bar').one()
+
+ ed = session.query(User).filter(User.addresses.any(
+ Address.email_address == 'ed@foo.bar')).one()
+
+ def test_select_statement(self):
+ """Query.select_statement(statement, **params)
+
+ users = session.query(User).select_statement(users_table.select())
+
+ """
+ users = session.query(User).from_statement(users_table.select()).all()
+ assert len(users) == 4
+
+ def test_select_text(self):
+ """Query.select_text(text, **params)
+
+ users = session.query(User).select_text('SELECT * FROM users')
+
+ """
+ users = session.query(User).from_statement('SELECT * FROM users').all()
+ assert len(users) == 4
+
+ def test_select_whereclause(self):
+ """Query.select_whereclause(whereclause=None, params=None, **kwargs)
+
+        users = session.query(User).select_whereclause(users_table.c.name=='ed')
+ users = session.query(User).select_whereclause("name='ed'")
+
+ """
+ users = session.query(User).filter(User.name=='ed').all()
+ assert len(users) == 1 and users[0].name == 'ed'
+
+ users = session.query(User).filter("name='ed'").all()
+ assert len(users) == 1 and users[0].name == 'ed'
+
+
+
+if __name__ == '__main__':
+ testenv.main()
User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
User(name='ed', addresses=[Address(email_address='foo@bar.com')])
] == sess.query(User).all()
+
+ def test_rollback(self):
+ mapper(User, users, properties={
+ 'addresses':dynamic_loader(mapper(Address, addresses))
+ })
+ sess = create_session(autoexpire=False, autocommit=False, autoflush=True)
+ u1 = User(name='jack')
+ u1.addresses.append(Address(email_address='lala@hoho.com'))
+ sess.save(u1)
+ sess.flush()
+ sess.commit()
+ u1.addresses.append(Address(email_address='foo@bar.com'))
+ self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
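+        # rollback() should discard the Address appended after the commit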
+ sess.rollback()
+ self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com')])
+
@testing.fails_on('maxdb')
def test_delete_nocascade(self):
mapper(User, users, properties={
from testlib import *
from testlib.fixtures import *
from query import QueryTest
+from sqlalchemy.orm import attributes
class EagerTest(FixtureTest):
keep_mappers = False
sess = create_session()
user = sess.query(User).get(7)
- assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True)
- assert not class_mapper(Address)._is_orphan(user.addresses[0])
+ assert getattr(User, 'addresses').hasparent(attributes.instance_state(user.addresses[0]), optimistic=True)
+ assert not class_mapper(Address)._is_orphan(attributes.instance_state(user.addresses[0]))
def test_orderby(self):
mapper(User, users, properties = {
})
mapper(User, users)
- assert [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))] == create_session().query(Address).filter(Address.id.in_([1, 4, 5])).all()
-
- assert [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))] == create_session().query(Address).filter(Address.id.in_([1, 4, 5])).limit(3).all()
-
sess = create_session()
- a = sess.query(Address).get(1)
+
+ for q in [
+ sess.query(Address).filter(Address.id.in_([1, 4, 5])),
+ sess.query(Address).filter(Address.id.in_([1, 4, 5])).limit(3)
+ ]:
+ sess.clear()
+ self.assertEquals(q.all(),
+ [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))]
+ )
+
+ a = sess.query(Address).filter(Address.id==1).first()
def go():
assert a.user_id==7
# assert that the eager loader added 'user_id' to the row
'user_id':deferred(addresses.c.user_id),
})
mapper(User, users, properties={'addresses':relation(Address, lazy=False)})
+
+ for q in [
+ sess.query(User).filter(User.id==7),
+ sess.query(User).filter(User.id==7).limit(1)
+ ]:
+ sess.clear()
+ self.assertEquals(q.all(),
+ [User(id=7, addresses=[Address(id=1)])]
+ )
- assert [User(id=7, addresses=[Address(id=1)])] == create_session().query(User).filter(User.id==7).all()
-
- assert [User(id=7, addresses=[Address(id=1)])] == create_session().query(User).limit(1).filter(User.id==7).all()
-
- sess = create_session()
+ sess.clear()
u = sess.query(User).get(7)
def go():
assert u.addresses[0].user_id==7
mapper(Dingaling, dingalings, properties={
'address_id':deferred(dingalings.c.address_id)
})
- sess = create_session()
+ sess.clear()
def go():
- u = sess.query(User).limit(1).get(8)
+ u = sess.query(User).get(8)
assert User(id=8, addresses=[Address(id=2, dingalings=[Dingaling(id=1)]), Address(id=3), Address(id=4)]) == u
self.assert_sql_count(testing.db, go, 1)
self.assert_sql_count(testing.db, go, 1)
def go():
- assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(keywords.c.name == 'red').all()
+ assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(Keyword.name == 'red').all()
self.assert_sql_count(testing.db, go, 1)
def go():
- assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(keywords.c.name == 'red').all()
+ assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(Keyword.name == 'red').all()
self.assert_sql_count(testing.db, go, 1)
q = sess.query(User)
def go():
- l = q.filter(s.c.u2_id==User.c.id).distinct().all()
+ l = q.filter(s.c.u2_id==User.id).distinct().all()
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
sess = create_session()
q = sess.query(Item)
- l = q.filter((Item.c.description=='item 2') | (Item.c.description=='item 5') | (Item.c.description=='item 3')).\
+ l = q.filter((Item.description=='item 2') | (Item.description=='item 5') | (Item.description=='item 3')).\
order_by(Item.id).limit(2).all()
assert fixtures.item_keyword_result[1:3] == l
)
]
- def test_basic(self):
+ def test_mapper_configured(self):
mapper(User, users, properties={
'addresses':relation(Address, lazy=False),
'orders':relation(Order)
sess = create_session()
+ oalias = aliased(Order)
def go():
- ret = sess.query(User).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+ ret = sess.query(User, oalias).join(('orders', oalias)).order_by(User.id, oalias.id).all()
self.assertEquals(ret, self._assert_result())
self.assert_sql_count(testing.db, go, 1)
sess = create_session()
+ oalias = aliased(Order)
def go():
- ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+ ret = sess.query(User, oalias).options(eagerload('addresses')).join(('orders', oalias)).order_by(User.id, oalias.id).all()
self.assertEquals(ret, self._assert_result())
self.assert_sql_count(testing.db, go, 6)
sess.clear()
def go():
- ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).options(eagerload('items', Order)).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+ ret = sess.query(User, oalias).options(eagerload('addresses'), eagerload(oalias.items)).join(('orders', oalias)).order_by(User.id, oalias.id).all()
self.assertEquals(ret, self._assert_result())
self.assert_sql_count(testing.db, go, 1)
sess.flush()
sess.clear()
-# l = sess.query(Widget).filter(Widget.name=='w1').all()
-# print l
assert [Widget(name='w1', children=[Widget(name='w2')])] == sess.query(Widget).filter(Widget.name==u'w1').all()
+class MixedEntitiesTest(FixtureTest, AssertsCompiledSQL):
+ keep_mappers = True
+ keep_data = True
+
+ def setup_mappers(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user'),
+ 'orders':relation(Order, backref='user'), # o2m, m2o
+ })
+ mapper(Address, addresses)
+ mapper(Order, orders, properties={
+ 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m
+ })
+ mapper(Item, items, properties={
+ 'keywords':relation(Keyword, secondary=item_keywords) #m2m
+ })
+ mapper(Keyword, keywords)
+
+ def test_two_entities(self):
+ sess = create_session()
+
+ # two FROM clauses
+ def go():
+ self.assertEquals(
+ [
+ (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+ (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+ ],
+ sess.query(User, Order).filter(User.id==Order.user_id).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).all(),
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ # one FROM clause
+ def go():
+ self.assertEquals(
+ [
+ (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+ (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+ ],
+ sess.query(User, Order).join(User.orders).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).all(),
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ def test_aliased_entity(self):
+ sess = create_session()
+
+ oalias = aliased(Order)
+
+ # two FROM clauses
+ def go():
+ self.assertEquals(
+ [
+ (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+ (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+ ],
+ sess.query(User, oalias).filter(User.id==oalias.user_id).options(eagerload(User.addresses), eagerload(oalias.items)).filter(User.id==9).all(),
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ # one FROM clause
+ def go():
+ self.assertEquals(
+ [
+ (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+ (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+ ],
+ sess.query(User, oalias).join((User.orders, oalias)).options(eagerload(User.addresses), eagerload(oalias.items)).filter(User.id==9).all(),
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ from sqlalchemy.engine.default import DefaultDialect
+
+        # improper setup: oalias is in the columns clause, but the join is to the
+        # plain Order entity. this should create two FROM clauses even though
+        # the query has a from_clause set up via the join
+ self.assert_compile(sess.query(User, oalias).join(User.orders).options(eagerload(oalias.items)).with_labels().statement,
+ "SELECT users.id AS users_id, users.name AS users_name, orders_1.id AS orders_1_id, "\
+ "orders_1.user_id AS orders_1_user_id, orders_1.address_id AS orders_1_address_id, "\
+ "orders_1.description AS orders_1_description, orders_1.isopen AS orders_1_isopen, items_1.id AS items_1_id, "\
+ "items_1.description AS items_1_description FROM users JOIN orders ON users.id = orders.user_id, "\
+ "orders AS orders_1 LEFT OUTER JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id "\
+ "LEFT OUTER JOIN items AS items_1 ON items_1.id = order_items_1.item_id ORDER BY users.id, items_1.id",
+ dialect=DefaultDialect()
+ )
+
class CyclicalInheritingEagerTest(ORMTest):
+
def define_tables(self, metadata):
global t1, t2
t1 = Table('t1', metadata,
session.save(User(name='bar', tags=[Tag(score1=5.0, score2=4.0), Tag(score1=50.0, score2=1.0), Tag(score1=15.0, score2=2.0)]))
session.flush()
session.clear()
+
+ for user in session.query(User).all():
+ self.assertEquals(user.query_score, user.prop_score)
def go():
- for user in session.query(User).all():
- self.assertEquals(user.query_score, user.prop_score)
- self.assert_sql_count(testing.db, go, 1)
-
-
- # fails for non labeled (fixed in 0.5):
- if labeled:
- def go():
- u = session.query(User).filter_by(name='joe').one()
- self.assertEquals(u.query_score, u.prop_score)
- self.assert_sql_count(testing.db, go, 1)
- else:
u = session.query(User).filter_by(name='joe').one()
self.assertEquals(u.query_score, u.prop_score)
+ self.assert_sql_count(testing.db, go, 1)
for t in (tags_table, users_table):
t.delete().execute()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
from testlib import *
from testlib.tables import *
+from testlib import fixtures
class EntityTest(TestBase, AssertsExecutionResults):
"""tests mappers that are constructed based on "entity names", which allows the same class
to have multiple primary mappers """
- @testing.uses_deprecated('SessionContext')
def setUpAll(self):
global user1, user2, address1, address2, metadata, ctx
metadata = MetaData(testing.db)
- ctx = SessionContext(create_session)
+ ctx = scoped_session(create_session)
user1 = Table('user1', metadata,
Column('user_id', Integer, Sequence('user1_id_seq', optional=True),
def tearDownAll(self):
metadata.drop_all()
def tearDown(self):
- ctx.current.clear()
+ ctx.clear()
clear_mappers()
for t in metadata.table_iterator(reverse=True):
t.delete().execute()
- @testing.uses_deprecated('SessionContextExt')
def testbasic(self):
"""tests a pair of one-to-many mapper structures, establishing that both
parent and child objects honor the "entity_name" attribute attached to the object
instances."""
- class User(object):pass
- class Address(object):pass
+ class User(object):
+ def __init__(self, **kw):
+ pass
+ class Address(object):
+ def __init__(self, **kw):
+ pass
- a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.mapper_extension)
- a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension)
+ a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.extension)
+ a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.extension)
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'addresses':relation(a1mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'addresses':relation(a2mapper)
- }, extension=ctx.mapper_extension)
-
+ }, extension=ctx.extension)
+
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
a1 = Address(_sa_entity_name='address1')
a2.email='a2@foo.com'
u2.addresses.append(a2)
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')]
assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
- u1 = ctx.current.query(User, entity_name='user1').first()
- ctx.current.refresh(u1)
- ctx.current.expire(u1)
+ u1 = ctx.query(User, entity_name='user1').first()
+ ctx.refresh(u1)
+ ctx.expire(u1)
def testcascade(self):
def testpolymorphic(self):
"""tests that entity_name can be used to have two kinds of relations on the same class."""
- class User(object):pass
- class Address1(object):pass
- class Address2(object):pass
+ class User(object):
+ def __init__(self, **kw):
+ pass
+ class Address1(object):
+ def __init__(self, **kw):
+ pass
+ class Address2(object):
+ def __init__(self, **kw):
+ pass
- a1mapper = mapper(Address1, address1, extension=ctx.mapper_extension)
- a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension)
+ a1mapper = mapper(Address1, address1, extension=ctx.extension)
+ a2mapper = mapper(Address2, address2, extension=ctx.extension)
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'addresses':relation(a1mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'addresses':relation(a2mapper)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
a2.email='a2@foo.com'
u2.addresses.append(a2)
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')]
assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
def testpolymorphic_deferred(self):
"""test that deferred columns load properly using entity names"""
- class User(object):pass
+ class User(object):
+ def __init__(self, **kwargs):
+ pass
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'name':deferred(user1.c.name)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'name':deferred(user2.c.name)
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
u2 = User(_sa_entity_name='user2')
u2.name='this is user 2'
- ctx.current.flush()
+ ctx.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
- ctx.current.clear()
- u1list = ctx.current.query(User, entity_name='user1').all()
- u2list = ctx.current.query(User, entity_name='user2').all()
+ ctx.clear()
+ u1list = ctx.query(User, entity_name='user1').all()
+ u2list = ctx.query(User, entity_name='user2').all()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
# the deferred column load requires that setup_loader() check that the correct DeferredColumnLoader
assert u1list[0].name == 'this is user 1'
assert u2list[0].name == 'this is user 2'
+class SelfReferentialTest(ORMTest):
+ def define_tables(self, metadata):
+ global nodes
+
+ nodes = Table('nodes', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent_id', Integer, ForeignKey('nodes.id')),
+ Column('data', String(50)),
+ Column('type', String(50)),
+ )
+
+ # fails inconsistently. entity name needs deterministic
+ # instrumentation.
+ def dont_test_relation(self):
+ class Node(fixtures.Base):
+ pass
+
+ foonodes = nodes.select().where(nodes.c.type=='foo').alias()
+ barnodes = nodes.select().where(nodes.c.type=='bar').alias()
+
+ # TODO: the order of instrumentation here is not deterministic;
+ # therefore the test fails sporadically since "Node.data" references
+ # different mappers at different times
+ m1 = mapper(Node, nodes)
+ m2 = mapper(Node, foonodes, entity_name='foo')
+ m3 = mapper(Node, barnodes, entity_name='bar')
+
+ m1.add_property('foonodes', relation(m2, primaryjoin=nodes.c.id==foonodes.c.parent_id,
+ backref=backref('foo_parent', remote_side=nodes.c.id, primaryjoin=nodes.c.id==foonodes.c.parent_id)))
+ m1.add_property('barnodes', relation(m3, primaryjoin=nodes.c.id==barnodes.c.parent_id,
+ backref=backref('bar_parent', remote_side=nodes.c.id, primaryjoin=nodes.c.id==barnodes.c.parent_id)))
+
+ sess = create_session()
+
+ n1 = Node(data='n1', type='bat')
+ n1.foonodes.append(Node(data='n2', type='foo'))
+ Node(data='n3', type='bar', bar_parent=n1)
+ sess.save(n1)
+ sess.flush()
+ sess.clear()
+
+ self.assertEquals(sess.query(Node, entity_name="bar").one(), Node(data='n3'))
+ self.assertEquals(sess.query(Node).filter(Node.data=='n1').one(), Node(data='n1', foonodes=[Node(data='n2')], barnodes=[Node(data='n3')]))
if __name__ == "__main__":
testenv.main()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
from testlib import *
from testlib.fixtures import *
import gc
sess.expire(u)
# object isnt refreshed yet, using dict to bypass trigger
assert u.__dict__.get('name') != 'jack'
- assert 'name' in u._state.expired_attributes
+ assert 'name' in attributes.instance_state(u).expired_attributes
sess.query(User).all()
# test that it refreshed
assert u.__dict__['name'] == 'jack'
- assert 'name' not in u._state.expired_attributes
+ assert 'name' not in attributes.instance_state(u).expired_attributes
def go():
assert u.name == 'jack'
u = s.get(User, 7)
s.clear()
- self.assertRaisesMessage(exceptions.InvalidRequestError, r"is not persistent within this Session", lambda: s.expire(u))
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u)
+
+ def test_get_refreshes(self):
+ mapper(User, users)
+ s = create_session()
+ u = s.get(User, 10)
+ s.expire_all()
+ def go():
+ u = s.get(User, 10) # get() refreshes
+ self.assert_sql_count(testing.db, go, 1)
+ def go():
+ self.assertEquals(u.name, 'chuck') # attributes unexpired
+ self.assert_sql_count(testing.db, go, 0)
+ def go():
+ u = s.get(User, 10) # expire flag reset, so not expired
+ self.assert_sql_count(testing.db, go, 0)
+
+ s.expire_all()
+ users.delete().where(User.id==10).execute()
+
+ # object is gone, get() returns None
+ assert u in s
+ assert s.get(User, 10) is None
+ assert u not in s # and expunges
+
+ # add it back
+ s.add(u)
+ # nope, raises ObjectDeletedError
+ self.assertRaises(orm_exc.ObjectDeletedError, getattr, u, 'name')
+
+ def test_refresh_cancels_expire(self):
+ mapper(User, users)
+ s = create_session()
+ u = s.get(User, 7)
+ s.expire(u)
+ s.refresh(u)
+
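+        # refresh() reloaded the row, so the expiration is cancelled and the
+        # get() below should emit no SQL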
+ def go():
+ u = s.get(User, 7)
+ self.assertEquals(u.name, 'jack')
+ self.assert_sql_count(testing.db, go, 0)
+
def test_expire_doesntload_on_set(self):
mapper(User, users)
sess.expire(u, attribute_names=['name'])
sess.expunge(u)
- try:
- u.name
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Instance <class 'testlib.fixtures.User'> is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed"
+ self.assertRaises(sa_exc.UnboundExecutionError, getattr, u, 'name')
- def test_pending_doesnt_raise(self):
+ def test_pending_raises(self):
+ # this was the opposite in 0.4, but the reasoning there seemed off.
+ # expiring a pending instance makes no sense, so should raise
mapper(User, users)
sess = create_session()
u = User(id=15)
sess.save(u)
- sess.expire(u, ['name'])
- assert u.name is None
+ self.assertRaises(sa_exc.InvalidRequestError, sess.expire, u, ['name'])
def test_no_instance_key(self):
# this tests an artificial condition such that
sess.expire(u, attribute_names=['name'])
sess.expunge(u)
- del u._instance_key
+ attributes.instance_state(u).key = None
assert 'name' not in u.__dict__
sess.save(u)
assert u.name == 'jack'
-
+
def test_expire_preserves_changes(self):
"""test that the expire load operation doesn't revert post-expire changes"""
orders.update(id=3).execute(description='order 3 modified')
assert o.isopen == 1
- assert o._state.dict['description'] == 'order 3 modified'
+ assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
def go():
sess.flush()
self.assert_sql_count(testing.db, go, 0)
u.addresses[0].email_address = 'someotheraddress'
s.expire(u)
u.name
- print u._state.dict
+ print attributes.instance_state(u).dict
assert u.addresses[0].email_address == 'ed@wood.com'
def test_expired_lazy(self):
sess.expire(o, attribute_names=['description'])
assert 'id' in o.__dict__
assert 'description' not in o.__dict__
- assert o._state.dict['isopen'] == 1
+ assert attributes.instance_state(o).dict['isopen'] == 1
orders.update(orders.c.id==3).execute(description='order 3 modified')
def go():
assert o.description == 'order 3 modified'
self.assert_sql_count(testing.db, go, 1)
- assert o._state.dict['description'] == 'order 3 modified'
+ assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
o.isopen = 5
sess.expire(o, attribute_names=['description'])
assert 'id' in o.__dict__
assert 'description' not in o.__dict__
assert o.__dict__['isopen'] == 5
- assert o._state.committed_state['isopen'] == 1
+ assert attributes.instance_state(o).committed_state['isopen'] == 1
def go():
assert o.description == 'order 3 modified'
self.assert_sql_count(testing.db, go, 1)
assert o.__dict__['isopen'] == 5
- assert o._state.dict['description'] == 'order 3 modified'
- assert o._state.committed_state['isopen'] == 1
+ assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
+ assert attributes.instance_state(o).committed_state['isopen'] == 1
sess.flush()
{'person_id':3, 'status':'old engineer'},
)
- def test_poly_select(self):
- mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
- mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
-
- sess = create_session()
- [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all()
-
- sess.expire(p1)
- sess.expire(e1, ['status'])
- sess.expire(e2)
-
- for p in [p1, e2]:
- assert 'name' not in p.__dict__
-
- assert 'name' in e1.__dict__
- assert 'status' not in e2.__dict__
- assert 'status' not in e1.__dict__
-
- e1.name = 'new engineer name'
-
- def go():
- sess.query(Person).all()
- self.assert_sql_count(testing.db, go, 3)
-
- for p in [p1, e1, e2]:
- assert 'name' in p.__dict__
-
- assert 'status' in e2.__dict__
- assert 'status' in e1.__dict__
- def go():
- assert e1.name == 'new engineer name'
- assert e2.name == 'engineer2'
- assert e1.status == 'new engineer'
- self.assert_sql_count(testing.db, go, 0)
- self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1']))
-
def test_poly_deferred(self):
- mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person', polymorphic_fetch='deferred')
+ mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
sess = create_session()
s = create_session()
u = s.get(User, 7)
s.clear()
- self.assertRaisesMessage(exceptions.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
def test_refresh_expired(self):
mapper(User, users)
--- /dev/null
+import testenv; testenv.configure_for_tests()
+import pickle
+from sqlalchemy import util
+import sqlalchemy.orm.attributes as attributes
+from sqlalchemy.orm.collections import collection
+from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
+from sqlalchemy.orm import clear_mappers
+from sqlalchemy.orm import InstrumentationManager
+
+from testlib import *
+
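+# Each scenario below runs against three bases: plain 'object' (native
+# instrumentation), MyBaseClass (the stock InstrumentationManager), and
+# MyClass (a fully customized manager with its own state and dict storage).
+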
+class MyTypesManager(InstrumentationManager):
+
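+    # Keeps SQLAlchemy's per-instance state and attribute dictionary in
+    # nonstandard locations ('_my_state' and '_goofy_dict') to demonstrate
+    # that the state/dict lookup machinery is pluggable.
+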
+ def instrument_attribute(self, class_, key, attr):
+ pass
+
+ def install_descriptor(self, class_, key, attr):
+ pass
+
+ def uninstall_descriptor(self, class_, key):
+ pass
+
+ def instrument_collection_class(self, class_, key, collection_class):
+ return MyListLike
+
+ def get_instance_dict(self, class_, instance):
+ return instance._goofy_dict
+
+ def initialize_instance_dict(self, class_, instance):
+ instance.__dict__['_goofy_dict'] = {}
+
+ def install_state(self, class_, instance, state):
+ instance.__dict__['_my_state'] = state
+
+ def state_getter(self, class_):
+ return lambda instance: instance.__dict__['_my_state']
+
+class MyListLike(list):
+ # add @appender, @remover decorators as needed
+ _sa_iterator = list.__iter__
+ def _sa_appender(self, item, _sa_initiator=None):
+ if _sa_initiator is not False:
+ self._sa_adapter.fire_append_event(item, _sa_initiator)
+ list.append(self, item)
+ append = _sa_appender
+ def _sa_remover(self, item, _sa_initiator=None):
+ self._sa_adapter.fire_pre_remove_event(_sa_initiator)
+ if _sa_initiator is not False:
+ self._sa_adapter.fire_remove_event(item, _sa_initiator)
+ list.remove(self, item)
+ remove = _sa_remover
+
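+# A rough sketch of how the ORM exercises the hooks above once a collection
+# is instrumented ('_sa_adapter' is attached by SQLAlchemy, not defined here):
+#
+#   blog.posts.append(post)   # -> _sa_appender -> fire_append_event(...)
+#   blog.posts.remove(post)   # -> _sa_remover  -> fire_remove_event(...)
+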
+class MyBaseClass(object):
+ __sa_instrumentation_manager__ = InstrumentationManager
+
+class MyClass(object):
+
+ # This proves that a staticmethod will work here; don't
+ # flatten this back to a class assignment!
+ def __sa_instrumentation_manager__(cls):
+ return MyTypesManager(cls)
+
+ __sa_instrumentation_manager__ = staticmethod(__sa_instrumentation_manager__)
+
+ # This proves SA can handle a class with non-string dict keys
+ locals()[42] = 99 # Don't remove this line!
+
+ def __init__(self, **kwargs):
+ for k in kwargs:
+ setattr(self, k, kwargs[k])
+
+ def __getattr__(self, key):
+ if is_instrumented(self, key):
+ return get_attribute(self, key)
+ else:
+ try:
+ return self._goofy_dict[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __setattr__(self, key, value):
+ if is_instrumented(self, key):
+ set_attribute(self, key, value)
+ else:
+ self._goofy_dict[key] = value
+
+    # note: Python never invokes __hasattr__ as a special method; hasattr()
+    # is driven by __getattr__ above. kept here to document the membership test.
+    def __hasattr__(self, key):
+ if is_instrumented(self, key):
+ return True
+ else:
+ return key in self._goofy_dict
+
+ def __delattr__(self, key):
+ if is_instrumented(self, key):
+ del_attribute(self, key)
+ else:
+ del self._goofy_dict[key]
+
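+# Note: '_goofy_dict' is created by MyTypesManager.initialize_instance_dict(),
+# so MyClass instances are only usable once attributes.register_class(MyClass)
+# has instrumented the class (each test below registers its classes first).
+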
+class UserDefinedExtensionTest(TestBase):
+ def tearDownAll(self):
+ clear_mappers()
+ attributes._install_lookup_strategy(util.symbol('native'))
+
+ def test_basic(self):
+ for base in (object, MyBaseClass, MyClass):
+ class User(base):
+ pass
+
+ attributes.register_class(User)
+ attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
+ attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
+ attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)
+
+ u = User()
+ u.user_id = 7
+ u.user_name = 'john'
+ u.email_address = 'lala@123.com'
+
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+ attributes.instance_state(u).commit_all()
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+ u.user_name = 'heythere'
+ u.email_address = 'foo@bar.com'
+ self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
+
+ def test_deferred(self):
+ for base in (object, MyBaseClass, MyClass):
+ class Foo(base):pass
+
+ data = {'a':'this is a', 'b':12}
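+            # minimal stand-in for a deferred-column loader: populate the
+            # requested keys from 'data' and report that they were set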
+ def loader(state, keys):
+ for k in keys:
+ state.dict[k] = data[k]
+ return attributes.ATTR_WAS_SET
+
+ attributes.register_class(Foo)
+ manager = attributes.manager_of_class(Foo)
+ manager.deferred_scalar_loader = loader
+ attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
+ attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+
+ assert Foo in attributes.instrumentation_registry.state_finders
+ f = Foo()
+ attributes.instance_state(f).expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ f.a = "this is some new a"
+ attributes.instance_state(f).expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ attributes.instance_state(f).expire_attributes(None)
+ f.a = "this is another new a"
+ self.assertEquals(f.a, "this is another new a")
+ self.assertEquals(f.b, 12)
+
+ attributes.instance_state(f).expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ del f.a
+ self.assertEquals(f.a, None)
+ self.assertEquals(f.b, 12)
+
+ attributes.instance_state(f).commit_all()
+ self.assertEquals(f.a, None)
+ self.assertEquals(f.b, 12)
+
+ def test_inheritance(self):
+ """tests that attributes are polymorphic"""
+
+ for base in (object, MyBaseClass, MyClass):
+ class Foo(base):pass
+ class Bar(Foo):pass
+
+ attributes.register_class(Foo)
+ attributes.register_class(Bar)
+
+ def func1():
+ print "func1"
+ return "this is the foo attr"
+ def func2():
+ print "func2"
+ return "this is the bar attr"
+ def func3():
+ print "func3"
+ return "this is the shared attr"
+ attributes.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1, useobject=True)
+ attributes.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3, useobject=True)
+ attributes.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2, useobject=True)
+
+ x = Foo()
+ y = Bar()
+ assert x.element == 'this is the foo attr'
+ assert y.element == 'this is the bar attr', y.element
+ assert x.element2 == 'this is the shared attr'
+ assert y.element2 == 'this is the shared attr'
+
+ def test_collection_with_backref(self):
+ for base in (object, MyBaseClass, MyClass):
+ class Post(base):pass
+ class Blog(base):pass
+
+ attributes.register_class(Post)
+ attributes.register_class(Blog)
+ attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
+ attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
+ b = Blog()
+ (p1, p2, p3) = (Post(), Post(), Post())
+ b.posts.append(p1)
+ b.posts.append(p2)
+ b.posts.append(p3)
+ self.assert_(b.posts == [p1, p2, p3])
+ self.assert_(p2.blog is b)
+
+ p3.blog = None
+ self.assert_(b.posts == [p1, p2])
+ p4 = Post()
+ p4.blog = b
+ self.assert_(b.posts == [p1, p2, p4])
+
+ p4.blog = b
+ p4.blog = b
+ self.assert_(b.posts == [p1, p2, p4])
+
+ # assert no failure removing None
+ p5 = Post()
+ p5.blog = None
+ del p5.blog
+
+ def test_history(self):
+ for base in (object, MyBaseClass, MyClass):
+ class Foo(base):
+ pass
+ class Bar(base):
+ pass
+
+ attributes.register_class(Foo)
+ attributes.register_class(Bar)
+ attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
+ attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
+ attributes.register_attribute(Bar, "name", uselist=False, useobject=False)
+
+ f1 = Foo()
+ f1.name = 'f1'
+
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], [], []))
+
+ b1 = Bar()
+ b1.name = 'b1'
+ f1.bars.append(b1)
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
+
+ attributes.instance_state(f1).commit_all()
+ attributes.instance_state(b1).commit_all()
+
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ([], ['f1'], []))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([], [b1], []))
+
+ f1.name = 'f1mod'
+ b2 = Bar()
+ b2.name = 'b2'
+ f1.bars.append(b2)
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], [], ['f1']))
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
+ f1.bars.remove(b1)
+ self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
+
+ def test_null_instrumentation(self):
+ class Foo(MyBaseClass):
+ pass
+ attributes.register_class(Foo)
+ attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
+ attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
+
+ assert Foo.name == attributes.manager_of_class(Foo).get_inst('name')
+ assert Foo.bars == attributes.manager_of_class(Foo).get_inst('bars')
+
+ def test_alternate_finders(self):
+ """Ensure the generic finder front-end deals with edge cases."""
+
+ class Unknown(object): pass
+ class Known(MyBaseClass): pass
+
+ attributes.register_class(Known)
+ k, u = Known(), Unknown()
+
+ assert attributes.manager_of_class(Unknown) is None
+ assert attributes.manager_of_class(Known) is not None
+ assert attributes.manager_of_class(None) is None
+
+ assert attributes.instance_state(k) is not None
+ self.assertRaises((AttributeError, KeyError),
+ attributes.instance_state, u)
+ self.assertRaises((AttributeError, KeyError),
+ attributes.instance_state, None)
+
+
+if __name__ == '__main__':
+    testenv.main()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
from testlib import *
import testlib.tables as tables
def test_selectby(self):
res = create_session(bind=testing.db).query(Foo).filter_by(range=5)
- assert res.order_by([Foo.c.bar])[0].bar == 5
- assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
+ assert res.order_by([Foo.bar])[0].bar == 5
+ assert res.order_by([desc(Foo.bar)])[0].bar == 95
@testing.unsupported('mssql')
@testing.fails_on('maxdb')
assert query.count() == 100
assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
assert query.filter(foo.c.bar<30).max(foo.c.bar) == 29
- assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29
- assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29
+ assert query.filter(foo.c.bar<30).values(func.max(foo.c.bar)).next()[0] == 29
+ assert query.filter(foo.c.bar<30).values(func.max(foo.c.bar)).next()[0] == 29
def test_aggregate_1(self):
if (testing.against('mysql') and
avg = query.filter(foo.c.bar < 30).avg(foo.c.bar)
assert round(avg, 1) == 14.5
- @testing.fails_on('firebird', 'mssql')
- @testing.uses_deprecated('Call to deprecated function apply_avg')
def test_aggregate_3(self):
query = create_session(bind=testing.db).query(Foo)
- avg_f = query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first()
+ avg_f = query.filter(foo.c.bar<30).values(func.avg(foo.c.bar)).next()[0]
assert round(avg_f, 1) == 14.5
- avg_o = query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one()
+ avg_o = query.filter(foo.c.bar<30).values(func.avg(foo.c.bar)).next()[0]
assert round(avg_o, 1) == 14.5
def test_filter(self):
query = create_session(bind=testing.db).query(Foo)
assert query.count() == 100
- assert query.filter(Foo.c.bar < 30).count() == 30
- res2 = query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
+ assert query.filter(Foo.bar < 30).count() == 30
+ res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10)
assert res2.count() == 19
def test_options(self):
def test_order_by(self):
query = create_session(bind=testing.db).query(Foo)
- assert query.order_by([Foo.c.bar])[0].bar == 0
- assert query.order_by([desc(Foo.c.bar)])[0].bar == 99
+ assert query.order_by([Foo.bar])[0].bar == 0
+ assert query.order_by([desc(Foo.bar)])[0].bar == 99
def test_offset(self):
query = create_session(bind=testing.db).query(Foo)
- assert list(query.order_by([Foo.c.bar]).offset(10))[0].bar == 10
+ assert list(query.order_by([Foo.bar]).offset(10))[0].bar == 10
def test_offset(self):
query = create_session(bind=testing.db).query(Foo)
})
session = create_session(bind=testing.db)
query = session.query(tables.User)
- x = query.join(['orders', 'items']).filter(tables.Item.c.item_id==2)
+ x = query.join(['orders', 'items']).filter(tables.Item.item_id==2)
print x.compile()
self.assert_result(list(x), tables.User, tables.user_result[2])
def test_outerjointo(self):
})
session = create_session(bind=testing.db)
query = session.query(tables.User)
- x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+ x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.order_id==None,tables.Item.item_id==2))
print x.compile()
self.assert_result(list(x), tables.User, *tables.user_result[1:3])
def test_outerjointo_count(self):
})
session = create_session(bind=testing.db)
query = session.query(tables.User)
- x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
+ x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.order_id==None,tables.Item.item_id==2)).count()
assert x==2
def test_from(self):
mapper(tables.User, tables.users, properties={
session = create_session(bind=testing.db)
query = session.query(tables.User)
x = query.select_from(tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)).\
- filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+ filter(or_(tables.Order.order_id==None,tables.Item.item_id==2))
print x.compile()
self.assert_result(list(x), tables.User, *tables.user_result[1:3])
res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)).distinct()
self.assertEqual(res.count(), 1)
-class SelfRefTest(ORMTest):
- def define_tables(self, metadata):
- global t1
- t1 = Table('t1', metadata,
- Column('id', Integer, primary_key=True),
- Column('parent_id', Integer, ForeignKey('t1.id'))
- )
- def test_noautojoin(self):
- class T(object):pass
- mapper(T, t1, properties={'children':relation(T)})
- sess = create_session(bind=testing.db)
- def go():
- sess.query(T).join('children')
- self.assertRaisesMessage(exceptions.InvalidRequestError,
- "Self-referential query on 'T\.children \(T\)' property requires aliased=True argument.", go)
-
- def go():
- sess.query(T).join(['children']).select_by(id=7)
- self.assertRaisesMessage(exceptions.InvalidRequestError,
- "Self-referential query on 'T\.children \(T\)' property requires aliased=True argument.", go)
-
if __name__ == "__main__":
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
from testlib import *
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import util
from sqlalchemy.orm import *
from testlib import *
from testlib import fixtures
else:
abc = bc = None
- mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a', polymorphic_fetch=fetchtype)
- mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b', polymorphic_fetch=fetchtype)
+ mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a')
+ mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b')
mapper(C, c, inherits=B, polymorphic_identity='c')
a1 = A(adata='a1')
return test_roundtrip
test_union = make_test('union')
- test_select = make_test('select')
- test_deferred = make_test('deferred')
+ test_none = make_test('none')
if __name__ == '__main__':
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
from sqlalchemy.orm import *
+from sqlalchemy.orm import exc as orm_exc
from testlib import *
from testlib import fixtures
'content_type':relation(content_types)
}, polymorphic_identity='contents')
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument"
def testbackref(self):
class Admin(User):pass
role_mapper = mapper(Role, roles)
user_mapper = mapper(User, users, properties = {
- 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+ 'roles' : relation(Role, secondary=user_roles, lazy=False)
}
)
admin_mapper = mapper(Admin, admins, inherits=user_mapper)
role_mapper = mapper(Role, roles)
user_mapper = mapper(User, users, properties = {
- 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+ 'roles' : relation(Role, secondary=user_roles, lazy=False)
}
)
try:
sess2.query(Base).with_lockmode('read').get(s1.id)
assert False
- except exceptions.ConcurrentModificationError, e:
+ except orm_exc.ConcurrentModificationError, e:
assert True
try:
sess2.flush()
assert False
- except exceptions.ConcurrentModificationError, e:
+ except orm_exc.ConcurrentModificationError, e:
assert True
sess2.refresh(s2)
s1.subdata = 'some new subdata'
sess.flush()
assert False
- except exceptions.ConcurrentModificationError, e:
+ except orm_exc.ConcurrentModificationError, e:
assert True
mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
self._do_test(True)
assert False
- except exceptions.SAWarning, e:
+ except sa_exc.SAWarning, e:
assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name.", str(e)
def test_explicit_pk(self):
assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Kurt knows how to hack"])
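+        # expiring a subclass-local attribute should simply reload it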
+ manager = session.query(Manager).one()
+ session.expire(manager, ['manager_data'])
+ self.assertEquals(manager.manager_data, "knows how to manage things")
+
def test_multi_level(self):
class Employee(object):
def __init__(self, name):
# clear and query forwards
sess.clear()
- node = sess.query(Table1).filter(Table1.c.id==t.id).first()
+ node = sess.query(Table1).filter(Table1.id==t.id).first()
assertlist = []
while (node):
assertlist.append(node)
# clear and query backwards
sess.clear()
- node = sess.query(Table1).filter(Table1.c.id==obj.id).first()
+ node = sess.query(Table1).filter(Table1.id==obj.id).first()
assertlist = []
while (node):
assertlist.insert(0, node)
backwards = repr(assertlist)
# everything should match !
- print "ORIGNAL", original
- print "BACKWARDS",backwards
- print "FORWARDS", forwards
assert original == forwards == backwards
if __name__ == '__main__':
import sets
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy.orm import exc as orm_exc
from testlib import *
from testlib import fixtures
class RoundTripTest(PolymorphTest):
pass
-def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, polymorphic_fetch=None, use_outer_joins=False):
+def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic):
"""generates a round trip test.
include_base - whether or not to include the base 'person' type in the union.
-use_literal_join - primary join condition is explicitly specified
+lazy_relation - whether the 'employees' relation loads lazily or eagerly.
+redefine_colprop - whether the 'name' column property is remapped to 'person_name'.
+with_polymorphic - polymorphic loading style: 'unions', 'joins', 'auto', or 'none'.
"""
def test_roundtrip(self):
- # create a union that represents both types of joins.
- if not polymorphic_fetch == 'union':
- person_join = None
- manager_join = None
- elif include_base:
- if use_outer_joins:
- person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
- manager_join = people.join(managers).outerjoin(boss)
- else:
+ if with_polymorphic == 'unions':
+ if include_base:
person_join = polymorphic_union(
{
'engineer':people.join(engineers),
'manager':people.join(managers),
'person':people.select(people.c.type=='person'),
}, None, 'pjoin')
-
- manager_join = people.join(managers).outerjoin(boss)
- else:
- if use_outer_joins:
- person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
- manager_join = people.join(managers).outerjoin(boss)
else:
person_join = polymorphic_union(
{
'engineer':people.join(engineers),
'manager':people.join(managers),
}, None, 'pjoin')
- manager_join = people.join(managers).outerjoin(boss)
+
+ manager_join = people.join(managers).outerjoin(boss)
+ person_with_polymorphic = ['*', person_join]
+ manager_with_polymorphic = ['*', manager_join]
+ elif with_polymorphic == 'joins':
+ person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
+ manager_join = people.join(managers).outerjoin(boss)
+ person_with_polymorphic = ['*', person_join]
+ manager_with_polymorphic = ['*', manager_join]
+ elif with_polymorphic == 'auto':
+ person_with_polymorphic = '*'
+ manager_with_polymorphic = '*'
+ else:
+ person_with_polymorphic = None
+ manager_with_polymorphic = None
if redefine_colprop:
- person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
+ person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
else:
- person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person')
+ person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person')
mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
- mapper(Manager, managers, inherits=person_mapper, select_table=manager_join, polymorphic_identity='manager')
+ mapper(Manager, managers, inherits=person_mapper, with_polymorphic=manager_with_polymorphic, polymorphic_identity='manager')
mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss')
- if use_literal_join:
- mapper(Company, companies, properties={
- 'employees': relation(Person, lazy=lazy_relation,
- primaryjoin=(people.c.company_id ==
- companies.c.company_id),
- cascade="all,delete-orphan",
- backref="company",
- order_by=people.c.person_id
- )
- })
- else:
- mapper(Company, companies, properties={
- 'employees': relation(Person, lazy=lazy_relation,
- cascade="all, delete-orphan",
- backref="company", order_by=people.c.person_id
- )
- })
+ mapper(Company, companies, properties={
+ 'employees': relation(Person, lazy=lazy_relation,
+ cascade="all, delete-orphan",
+ backref="company", order_by=people.c.person_id
+ )
+ })
if redefine_colprop:
person_attribute_name = 'person_name'
def go():
cc = session.query(Company).get(c.company_id)
- for e in cc.employees:
- assert e._instance_key[0] == Person
self.assertEquals(cc.employees, employees)
if not lazy_relation:
- if polymorphic_fetch=='union':
+ if with_polymorphic != 'none':
self.assert_sql_count(testing.db, go, 1)
else:
self.assert_sql_count(testing.db, go, 5)
else:
- if polymorphic_fetch=='union':
+ if with_polymorphic != 'none':
self.assert_sql_count(testing.db, go, 2)
else:
self.assert_sql_count(testing.db, go, 6)
session.flush()
session.clear()
- if polymorphic_fetch == 'select':
- def go():
- session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
- self.assert_sql_count(testing.db, go, 2)
- session.clear()
- dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
- def go():
- # assert that only primary table is queried for already-present-in-session
- d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
- self.assert_sql_count(testing.db, go, 1)
+ def go():
+ session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+ self.assert_sql_count(testing.db, go, 1)
+ session.clear()
+ dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+ def go():
+ # assert that only primary table is queried for already-present-in-session
+ d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+ self.assert_sql_count(testing.db, go, 1)
# test standalone orphans
daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'})
session.save(daboss)
- self.assertRaises(exceptions.FlushError, session.flush)
+ self.assertRaises(orm_exc.FlushError, session.flush)
c = session.query(Company).first()
daboss.company = c
manager_list = [e for e in c.employees if isinstance(e, Manager)]
self.assertEquals(people.count().scalar(), 0)
test_roundtrip = _function_named(
- test_roundtrip, "test_%s%s%s%s%s" % (
+ test_roundtrip, "test_%s%s%s_%s" % (
(lazy_relation and "lazy" or "eager"),
(include_base and "_inclbase" or ""),
(redefine_colprop and "_redefcol" or ""),
- (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
- (use_outer_joins and '_outerjoins' or '')))
+ with_polymorphic))
setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
-for include_base in [True, False]:
- for lazy_relation in [True, False]:
- for redefine_colprop in [True, False]:
- for use_literal_join in [True, False]:
- for polymorphic_fetch in ['union', 'select', 'deferred']:
- if polymorphic_fetch == 'union':
- for use_outer_joins in [True, False]:
- generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, use_outer_joins)
- else:
- generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, False)
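+# include_base only changes how the polymorphic union is built, so it is
+# only varied for the 'unions' style.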
+for lazy_relation in [True, False]:
+ for redefine_colprop in [True, False]:
+ for with_polymorphic in ['unions', 'joins', 'auto', 'none']:
+ if with_polymorphic == 'unions':
+ for include_base in [True, False]:
+ generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic)
+ else:
+ generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic)
if __name__ == "__main__":
testenv.main()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import util
from sqlalchemy.orm import *
from testlib import *
from testlib import fixtures
class Car(PersistentObject):
def __repr__(self):
- return "Car number %d, name %s" % i(self.car_id, self.name)
+ return "Car number %d, name %s" % (self.car_id, self.name)
class Offraod_Car(Car):
def __repr__(self):
session.save(car2)
session.flush()
- # test these twice because theres caching involved, as well previous issues that modified the polymorphic union
- for x in range(0, 2):
- r = session.query(Person).filter(people.c.name.like('%2')).join('status').filter_by(name="active")
- assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
- r = session.query(Engineer).join('status').filter(people.c.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active"))
- assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
- # this test embeds the original polymorphic union (employee_join) fully
- # into the WHERE criterion, using a correlated select. ticket #577 tracks
- # that Query's adaptation of the WHERE clause does not dig into the
- # mapped selectable itself, which permanently breaks the mapped selectable.
- r = session.query(Person).filter(exists([Car.c.owner], Car.c.owner==employee_join.c.person_id))
- assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
+    # this particular adaptation used to cause a recursion overflow;
+ # added here for testing
+ e = exists([Car.owner], Car.owner==employee_join.c.person_id)
+ Query(Person)._adapt_clause(employee_join, False, False)
+
+ r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active")
+ assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
+ r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active"))
+ assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
+
+ r = session.query(Person).filter(exists([1], Car.owner==Person.person_id))
+ assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
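+
+# A standalone sketch of the correlated-EXISTS pattern exercised above,
+# assuming this module's Person/Car mappings and an open session:
+# exists([1], ...) keeps the mapped selectable itself out of the
+# subquery, which is what avoids the adaptation recursion noted earlier.
+def _exists_sketch(session):
+    return session.query(Person).filter(
+        exists([1], Car.owner == Person.person_id)).all()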
class MultiLevelTest(ORMTest):
def define_tables(self, metadata):
import sets
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from testlib import *
from testlib import fixtures
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.engine import default
class Company(fixtures.Base):
pass
pass
def make_test(select_type):
- class PolymorphicQueryTest(ORMTest):
+ class PolymorphicQueryTest(ORMTest, AssertsCompiledSQL):
keep_data = True
keep_mappers = True
def test_primary_eager_aliasing(self):
sess = create_session()
+
+ # assert the SQL itself here to ensure no over-joining is taking place
+ if select_type == '':
+ self.assert_compile(
+ sess.query(Person).options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().statement,
+ "SELECT people.person_id AS people_person_id, people.company_id AS people_company_id, "\
+ "people.name AS people_name, people.type AS people_type FROM people ORDER BY people.person_id LIMIT 2 OFFSET 1",
+ dialect=default.DefaultDialect())
+
def go():
self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3].all(), all_employees[1:3])
self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4))
sess = create_session()
+
+ if select_type == '':
+ self.assert_compile(
+ sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().statement,
+ "SELECT anon_1.people_person_id AS anon_1_people_person_id, anon_1.people_company_id AS anon_1_people_company_id, "\
+ "anon_1.people_name AS anon_1_people_name, anon_1.people_type AS anon_1_people_type, anon_1.engineers_person_id AS "\
+ "anon_1_engineers_person_id, anon_1.engineers_status AS anon_1_engineers_status, anon_1.engineers_engineer_name AS "\
+ "anon_1_engineers_engineer_name, anon_1.engineers_primary_language AS anon_1_engineers_primary_language, "\
+ "anon_1.managers_person_id AS anon_1_managers_person_id, anon_1.managers_status AS anon_1_managers_status, "\
+ "anon_1.managers_manager_name AS anon_1_managers_manager_name, anon_1.boss_boss_id AS anon_1_boss_boss_id, "\
+ "anon_1.boss_golf_swing AS anon_1_boss_golf_swing, machines_1.machine_id AS machines_1_machine_id, machines_1.name AS "\
+ "machines_1_name, machines_1.engineer_id AS machines_1_engineer_id FROM (SELECT people.person_id AS people_person_id, "\
+ "people.company_id AS people_company_id, people.name AS people_name, people.type AS people_type, engineers.person_id AS "\
+ "engineers_person_id, engineers.status AS engineers_status, engineers.engineer_name AS engineers_engineer_name, "\
+ "engineers.primary_language AS engineers_primary_language, managers.person_id AS managers_person_id, managers.status "\
+ "AS managers_status, managers.manager_name AS managers_manager_name, boss.boss_id AS boss_boss_id, boss.golf_swing "\
+ "AS boss_golf_swing FROM people LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id LEFT OUTER JOIN "\
+ "managers ON people.person_id = managers.person_id LEFT OUTER JOIN boss ON managers.person_id = boss.boss_id ORDER BY "\
+ "people.person_id LIMIT 2 OFFSET 1) AS anon_1 LEFT OUTER JOIN machines AS machines_1 ON anon_1.engineers_person_id = "\
+ "machines_1.engineer_id ORDER BY anon_1.people_person_id, machines_1.machine_id",
+ dialect=default.DefaultDialect())
+
def go():
self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3].all(), all_employees[1:3])
self.assert_sql_count(testing.db, go, 3)
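            # The behavior asserted above, in brief: with a LIMIT/OFFSET
            # present, the eager loader applies the limit to the lead Person
            # rows in an inner SELECT and LEFT OUTER JOINs machines outside
            # of it. A sketch using this test's names:
            #
            #   q = sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))
            #   people_page = q[1:3].all()   # LIMIT 2 OFFSET 1 applies to Person rows only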
# for all mappers, ensure the primary key has been calculated as just the "person_id"
# column
- self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert"))
- self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert"))
- self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss"))
+ self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+ self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+ self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"))
def test_filter_on_subclass(self):
sess = create_session()
def test_join_from_polymorphic(self):
sess = create_session()
-
+
for aliased in (True, False):
self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
- self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+ self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
def test_join_from_with_polymorphic(self):
sess = create_session()
self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
sess.clear()
- self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+ self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
def test_join_to_polymorphic(self):
sess = create_session()
self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
-
+
def test_polymorphic_any(self):
sess = create_session()
Manager(name="dogbert", manager_name="dogbert", status="regular manager"),
Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer")
]
+ self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
+
def go():
self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
]
sess = create_session()
+
def go():
# test load Companies with lazy load to 'employees'
self.assertEquals(sess.query(Company).all(), assert_result)
        # with select_type='', the eagerload doesn't take effect here;
        # it eagerloads company->people, then issues a load for each of 5 rows, then lazyloads "machines"
self.assert_sql_count(testing.db, go, {'':7, 'Polymorphic':1}.get(select_type, 2))
-
+
def test_eagerload_on_subclass(self):
sess = create_session()
def go():
def test_join_to_subclass(self):
sess = create_session()
+ self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
if select_type == '':
self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
+
+ ealias = aliased(Engineer)
+ self.assertEquals(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1])
+
self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3])
self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2])
self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2)
+ def test_from_alias(self):
+ sess = create_session()
+
+ palias = aliased(Person)
+ self.assertEquals(
+ sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(),
+ [e1, e2]
+ )
+
+ def test_self_referential(self):
+ sess = create_session()
+
+ c1_employees = [e1, e2, b1, m1]
+
+ palias = aliased(Person)
+ self.assertEquals(
+ sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+ filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(),
+ [
+ (m1, e1),
+ (m1, e2),
+ (m1, b1),
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+ filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(),
+ [
+ (m1, e1),
+ (m1, e2),
+ (m1, b1),
+ ]
+ )
+
+ def test_nesting_queries(self):
+ sess = create_session()
+
+ # query.statement places a flag "no_adapt" on the returned statement. This prevents
+ # the polymorphic adaptation in the second "filter" from hitting it, which would pollute
+        # the subquery and usually results in recursion overflow errors within the adaptation.
+ subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar()
+
+ self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
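+
+        # an equivalent sketch using a plain select(), which is likewise
+        # opaque to the outer query's adaptation (a hypothetical rewrite):
+        #
+        #   inner = select([engineers.c.person_id],
+        #                  engineers.c.primary_language == 'java').as_scalar()
+        #   sess.query(Person).filter(Person.person_id == inner)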
+
+
+ def test_mixed_entities(self):
+ sess = create_session()
+
+ self.assertEquals(
+ sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+ [(u'Elbonia, Inc.',
+ Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))]
+ )
+
+ self.assertEquals(
+ sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+ [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
+ u'Elbonia, Inc.')]
+ )
+
+ self.assertEquals(
+ sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+ [(u'vlad',u'Elbonia, Inc.')]
+ )
+
+ self.assertEquals(
+ sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(),
+ [(u'java',), (u'c++',), (u'cobol',)]
+ )
+
+ if select_type != '':
+ self.assertEquals(
+ sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(),
+ [
+ (Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'),
+ (Engineer(status=u'regular engineer',engineer_name=u'wally',name=u'wally',company_id=1,primary_language=u'c++',person_id=2,type=u'engineer'), u'MegaCorp, Inc.'),
+ (Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',company_id=2,primary_language=u'cobol',person_id=5,type=u'engineer'), u'Elbonia, Inc.')
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(),
+ [(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')]
+ )
+
+ palias = aliased(Person)
+ self.assertEquals(
+ sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+ [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
+ u'Elbonia, Inc.',
+ Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))]
+ )
+
+ self.assertEquals(
+ sess.query(palias, Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+ [(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'),
+ u'Elbonia, Inc.',
+ Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),)
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+ [(u'vlad', u'Elbonia, Inc.', u'dilbert')]
+ )
+
+ palias = aliased(Person)
+ self.assertEquals(
+ sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+ filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(),
+ [(u'manager', u'dogbert', u'engineer', u'dilbert'),
+ (u'manager', u'dogbert', u'engineer', u'wally'),
+ (u'manager', u'dogbert', u'boss', u'pointy haired boss')]
+ )
+
+ self.assertEquals(
+ sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(),
+ [(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'),
+ (u'dogbert', u'review #3'),
+ (u'pointy haired boss', u'review #1'),
+ (u'vlad', u'elbonian missive #3'),
+ (u'wally', u'tps report #3'),
+ (u'wally', u'tps report #4'),
+ ]
+ )
+
+ if select_type != '':
+ self.assertEquals(
+ sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(),
+ [(1, )]
+ )
+
+ self.assertEquals(
+ sess.query(Company.name, func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(),
+ [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
+ )
+
+ self.assertEquals(
+ sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(),
+ [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
+ )
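+
+        # row shape for these mixed queries, in brief: each entity or column
+        # named in query() contributes one position per result row, so rows
+        # unpack directly. A sketch:
+        #
+        #   for engineer, company_name in sess.query(Engineer, Company.name).join(Company.employees):
+        #       print engineer.primary_language, company_name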
+
+
PolymorphicQueryTest.__name__ = "Polymorphic%sTest" % select_type
return PolymorphicQueryTest
self.assertEquals(sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), Engineer(name='dilbert'))
- def test_noalias_raises(self):
- sess = create_session()
- def go():
- sess.query(Engineer).join('reports_to')
- self.assertRaises(exceptions.InvalidRequestError, go)
class M2MFilterTest(ORMTest):
keep_mappers = True
sess = create_session()
self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+
+class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL):
+ def define_tables(self, metadata):
+ Base = declarative_base(metadata=metadata)
+
+ secondary_table = Table('secondary', Base.metadata,
+ Column('left_id', Integer, ForeignKey('parent.id'), nullable=False),
+ Column('right_id', Integer, ForeignKey('parent.id'), nullable=False))
+
+ global Parent, Child1, Child2
+ class Parent(Base):
+ __tablename__ = 'parent'
+ id = Column(Integer, primary_key=True)
+ cls = Column(String(50))
+ __mapper_args__ = dict(polymorphic_on = cls )
+
+ class Child1(Parent):
+ __tablename__ = 'child1'
+ id = Column(Integer, ForeignKey('parent.id'), primary_key=True)
+ __mapper_args__ = dict(polymorphic_identity = 'child1')
+
+ class Child2(Parent):
+ __tablename__ = 'child2'
+ id = Column(Integer, ForeignKey('parent.id'), primary_key=True)
+ __mapper_args__ = dict(polymorphic_identity = 'child2')
+
+ Child1.left_child2 = relation(Child2, secondary = secondary_table,
+ primaryjoin = Parent.id == secondary_table.c.right_id,
+ secondaryjoin = Parent.id == secondary_table.c.left_id,
+ uselist = False,
+ )
+
+ def test_eager_join(self):
+ session = create_session()
+ c1 = Child1()
+ c1.left_child2 = Child2()
+ session.add(c1)
+ session.flush()
+
+ q = session.query(Child1).options(eagerload('left_child2'))
+
+        # test that the splicing of the join works here and doesn't break in the middle of "parent join child1"
+ self.assert_compile(q.limit(1).with_labels().statement,
+ "SELECT anon_1.child1_id AS anon_1_child1_id, anon_1.parent_id AS anon_1_parent_id, "\
+ "anon_1.parent_cls AS anon_1_parent_cls, anon_2.child2_id AS anon_2_child2_id, anon_2.parent_id AS anon_2_parent_id, "\
+ "anon_2.parent_cls AS anon_2_parent_cls FROM (SELECT child1.id AS child1_id, parent.id AS parent_id, "\
+ "parent.cls AS parent_cls, parent.id AS parent_oid FROM parent JOIN child1 ON parent.id = child1.id ORDER BY parent.id "\
+ "LIMIT 1) AS anon_1 LEFT OUTER JOIN secondary AS secondary_1 ON anon_1.parent_id = secondary_1.right_id LEFT OUTER JOIN "\
+ "(SELECT parent.id AS parent_id, parent.cls AS parent_cls, child2.id AS child2_id FROM parent JOIN child2 ON parent.id = child2.id) "\
+ "AS anon_2 ON anon_2.parent_id = secondary_1.left_id ORDER BY anon_1.child1_id"
+ , dialect=default.DefaultDialect())
+ assert q.first() is c1
+
if __name__ == "__main__":
testenv.main()
assert session.query(Engineer).all() == [e1, e2]
assert session.query(Manager).all() == [m1]
assert session.query(JuniorEngineer).all() == [e2]
-
+
+ m1 = session.query(Manager).one()
+ session.expire(m1, ['manager_data'])
+ self.assertEquals(m1.manager_data, "knows how to manage things")
+
if __name__ == '__main__':
testenv.main()
--- /dev/null
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
+from sqlalchemy import util
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import create_session
+from sqlalchemy.orm import interfaces
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relation
+
+from testlib.testing import eq_, ne_
+from testlib.compat import _function_named
+from testlib import TestBase
+
+
+def modifies_instrumentation_finders(fn):
+ def decorated(*args, **kw):
+ pristine = attributes.instrumentation_finders[:]
+ try:
+ fn(*args, **kw)
+ finally:
+ del attributes.instrumentation_finders[:]
+ attributes.instrumentation_finders.extend(pristine)
+ return _function_named(decorated, fn.func_name)
+
+def with_lookup_strategy(strategy):
+ def decorate(fn):
+ def wrapped(*args, **kw):
+ current = attributes._lookup_strategy
+ try:
+ attributes._install_lookup_strategy(strategy)
+ return fn(*args, **kw)
+ finally:
+ attributes._install_lookup_strategy(current)
+ return _function_named(wrapped, fn.func_name)
+ return decorate
+
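+# Both decorators above follow the same snapshot/restore discipline for
+# process-global state. The shape generalizes to any mutable list; a
+# hypothetical helper (not used by these tests):
+def preserving_sequence(seq):
+    def decorate(fn):
+        def wrapped(*args, **kw):
+            saved = seq[:]          # snapshot
+            try:
+                return fn(*args, **kw)
+            finally:
+                seq[:] = saved      # restore in place
+        return _function_named(wrapped, fn.func_name)
+    return decorate
+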
+
+class InitTest(TestBase):
+ def fixture(self):
+ return Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column('type', Integer),
+ Column('x', Integer),
+ Column('y', Integer))
+
+ def register(self, cls, canary):
+ original_init = cls.__init__
+ attributes.register_class(cls)
+ ne_(cls.__init__, original_init)
+ manager = attributes.manager_of_class(cls)
+ def on_init(state, instance, args, kwargs):
+ canary.append((cls, 'on_init', type(instance)))
+ manager.events.add_listener('on_init', on_init)
+
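+    # Naming scheme for the tests below: an uppercase letter marks a
+    # class passed through self.register() (instrumented), a lowercase
+    # letter a plain class, and a trailing 'i' a class that defines its
+    # own __init__.
+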
+ def test_ai(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+
+ obj = A()
+ eq_(inits, [(A, '__init__')])
+
+ def test_A(self):
+ inits = []
+
+ class A(object): pass
+ self.register(A, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A)])
+
+ def test_Ai(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ def test_ai_B(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+
+ class B(A): pass
+ self.register(B, inits)
+
+ obj = A()
+ eq_(inits, [(A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+ def test_ai_Bi(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+ self.register(B, inits)
+
+ obj = A()
+ eq_(inits, [(A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+ def test_Ai_bi(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')])
+
+ def test_Ai_Bi(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+ self.register(B, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+ def test_Ai_B(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+ def test_Ai_Bi_Ci(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+ self.register(B, inits)
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ super(C, self).__init__()
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'),
+ (A, '__init__')])
+
+ def test_Ai_bi_Ci(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ super(B, self).__init__()
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ super(C, self).__init__()
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'),
+ (A, '__init__')])
+
+ def test_Ai_b_Ci(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A): pass
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ super(C, self).__init__()
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(A, 'on_init', B), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')])
+
+ def test_Ai_B_Ci(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ super(C, self).__init__()
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')])
+
+ def test_Ai_B_C(self):
+ inits = []
+
+ class A(object):
+ def __init__(self):
+ inits.append((A, '__init__'))
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ class C(B): pass
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (A, '__init__')])
+
+ def test_A_Bi_C(self):
+ inits = []
+
+ class A(object): pass
+ self.register(A, inits)
+
+ class B(A):
+ def __init__(self):
+ inits.append((B, '__init__'))
+ self.register(B, inits)
+
+ class C(B): pass
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A)])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B), (B, '__init__')])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (B, '__init__')])
+
+ def test_A_B_Ci(self):
+ inits = []
+
+ class A(object): pass
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ class C(B):
+ def __init__(self):
+ inits.append((C, '__init__'))
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A)])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B)])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C), (C, '__init__')])
+
+ def test_A_B_C(self):
+ inits = []
+
+ class A(object): pass
+ self.register(A, inits)
+
+ class B(A): pass
+ self.register(B, inits)
+
+ class C(B): pass
+ self.register(C, inits)
+
+ obj = A()
+ eq_(inits, [(A, 'on_init', A)])
+
+ del inits[:]
+
+ obj = B()
+ eq_(inits, [(B, 'on_init', B)])
+
+ del inits[:]
+ obj = C()
+ eq_(inits, [(C, 'on_init', C)])
+
+
+class MapperInitTest(TestBase):
+
+ def fixture(self):
+ return Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column('type', Integer),
+ Column('x', Integer),
+ Column('y', Integer))
+
+ def test_partially_mapped_inheritance(self):
+ class A(object):
+ pass
+
+ class B(A):
+ pass
+
+ class C(B):
+ def __init__(self):
+ pass
+
+ mapper(A, self.fixture())
+
+ a = attributes.instance_state(A())
+ assert isinstance(a, attributes.InstanceState)
+ assert type(a) is not attributes.InstanceState
+
+ b = attributes.instance_state(B())
+ assert isinstance(b, attributes.InstanceState)
+ assert type(b) is not attributes.InstanceState
+
+ # C is unmanaged
+ cobj = C()
+ self.assertRaises((AttributeError, TypeError),
+ attributes.instance_state, cobj)
+
+class InstrumentationCollisionTest(TestBase):
+ def test_none(self):
+ class A(object): pass
+ attributes.register_class(A)
+
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+ class B(object):
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+ attributes.register_class(B)
+
+ class C(object):
+ __sa_instrumentation_manager__ = attributes.ClassManager
+ attributes.register_class(C)
+
+ def test_single_down(self):
+ class A(object): pass
+ attributes.register_class(A)
+
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+ class B(A):
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+
+ self.assertRaises(TypeError, attributes.register_class, B)
+
+ def test_single_up(self):
+
+ class A(object): pass
+ # delay registration
+
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+ class B(A):
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+ attributes.register_class(B)
+ self.assertRaises(TypeError, attributes.register_class, A)
+
+ def test_diamond_b1(self):
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+ class A(object): pass
+ class B1(A): pass
+ class B2(A):
+ __sa_instrumentation_manager__ = mgr_factory
+ class C(object): pass
+
+ self.assertRaises(TypeError, attributes.register_class, B1)
+
+ def test_diamond_b2(self):
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+ class A(object): pass
+ class B1(A): pass
+ class B2(A):
+ __sa_instrumentation_manager__ = mgr_factory
+ class C(object): pass
+
+ self.assertRaises(TypeError, attributes.register_class, B2)
+
+ def test_diamond_c_b(self):
+ mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+ class A(object): pass
+ class B1(A): pass
+ class B2(A):
+ __sa_instrumentation_manager__ = mgr_factory
+ class C(object): pass
+
+ attributes.register_class(C)
+ self.assertRaises(TypeError, attributes.register_class, B1)
+
+
+class OnLoadTest(TestBase):
+ """Check that Events.on_load is not hit in regular attributes operations."""
+
+ def test_basic(self):
+ import pickle
+
+ global A
+ class A(object):
+ pass
+
+ def canary(instance): assert False
+
+ try:
+ attributes.register_class(A)
+ manager = attributes.manager_of_class(A)
+ manager.events.add_listener('on_load', canary)
+
+ a = A()
+ p_a = pickle.dumps(a)
+ re_a = pickle.loads(p_a)
+ finally:
+ del A
+
+
+class ExtendedEventsTest(TestBase):
+ """Allow custom Events implementations."""
+
+ @modifies_instrumentation_finders
+ def test_subclassed(self):
+ class MyEvents(attributes.Events):
+ pass
+ class MyClassManager(attributes.ClassManager):
+ event_registry_factory = MyEvents
+
+ attributes.instrumentation_finders.insert(0, lambda cls: MyClassManager)
+
+ class A(object): pass
+
+ attributes.register_class(A)
+ manager = attributes.manager_of_class(A)
+ assert isinstance(manager.events, MyEvents)
+
+
+class NativeInstrumentationTest(TestBase):
+ @with_lookup_strategy(util.symbol('native'))
+ def test_register_reserved_attribute(self):
+ class T(object): pass
+
+ attributes.register_class(T)
+ manager = attributes.manager_of_class(T)
+
+ sa = attributes.ClassManager.STATE_ATTR
+ ma = attributes.ClassManager.MANAGER_ATTR
+
+ fails = lambda method, attr: self.assertRaises(
+ KeyError, getattr(manager, method), attr, property())
+
+ fails('install_member', sa)
+ fails('install_member', ma)
+ fails('install_descriptor', sa)
+ fails('install_descriptor', ma)
+
+ @with_lookup_strategy(util.symbol('native'))
+ def test_mapped_stateattr(self):
+ t = Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column(attributes.ClassManager.STATE_ATTR, Integer))
+
+ class T(object): pass
+
+ self.assertRaises(KeyError, mapper, T, t)
+
+ @with_lookup_strategy(util.symbol('native'))
+ def test_mapped_managerattr(self):
+ t = Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column(attributes.ClassManager.MANAGER_ATTR, Integer))
+
+ class T(object): pass
+ self.assertRaises(KeyError, mapper, T, t)
+
+
+class MiscTest(TestBase):
+ """Seems basic, but not directly covered elsewhere!"""
+
+ def test_compileonattr(self):
+ t = Table('t', MetaData(),
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ class A(object): pass
+ mapper(A, t)
+
+ a = A()
+ assert a.id is None
+
+ def test_compileonattr_rel(self):
+ m = MetaData()
+ t1 = Table('t1', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ t2 = Table('t2', m,
+ Column('id', Integer, primary_key=True),
+ Column('t1_id', Integer, ForeignKey('t1.id')))
+ class A(object): pass
+ class B(object): pass
+ mapper(A, t1, properties=dict(bs=relation(B)))
+ mapper(B, t2)
+
+ a = A()
+ assert not a.bs
+
+ def test_compileonattr_rel_backref_a(self):
+ m = MetaData()
+ t1 = Table('t1', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ t2 = Table('t2', m,
+ Column('id', Integer, primary_key=True),
+ Column('t1_id', Integer, ForeignKey('t1.id')))
+
+ class Base(object):
+ def __init__(self, *args, **kwargs):
+ pass
+
+ for base in object, Base:
+ class A(base): pass
+ class B(base): pass
+ mapper(A, t1, properties=dict(bs=relation(B, backref='a')))
+ mapper(B, t2)
+
+ b = B()
+ assert b.a is None
+ a = A()
+ b.a = a
+
+ session = create_session()
+ session.save(b)
+ assert a in session, "base is %s" % base
+
+ def test_compileonattr_rel_backref_b(self):
+ m = MetaData()
+ t1 = Table('t1', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ t2 = Table('t2', m,
+ Column('id', Integer, primary_key=True),
+ Column('t1_id', Integer, ForeignKey('t1.id')))
+
+ class Base(object):
+ def __init__(self): pass
+ class Base_AKW(object):
+ def __init__(self, *args, **kwargs): pass
+
+ for base in object, Base, Base_AKW:
+ class A(base): pass
+ class B(base): pass
+ mapper(A, t1)
+ mapper(B, t2, properties=dict(a=relation(A, backref='bs')))
+
+ a = A()
+ b = B()
+ b.a = a
+
+ session = create_session()
+ session.save(a)
+ assert b in session, 'base: %s' % base
+
+ def test_compileonattr_rel_entity_name(self):
+ m = MetaData()
+ t1 = Table('t1', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ t2 = Table('t2', m,
+ Column('id', Integer, primary_key=True),
+ Column('t1_id', Integer, ForeignKey('t1.id')))
+ class A(object): pass
+ class B(object): pass
+ mapper(A, t1, properties=dict(bs=relation(B)), entity_name='x')
+ mapper(B, t2)
+
+ a = A()
+ assert not a.bs
+
+class FinderTest(TestBase):
+ def test_standard(self):
+ class A(object): pass
+
+ attributes.register_class(A)
+
+ eq_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+ def test_nativeext_interfaceexact(self):
+ class A(object):
+ __sa_instrumentation_manager__ = interfaces.InstrumentationManager
+
+ attributes.register_class(A)
+ ne_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+ def test_nativeext_submanager(self):
+ class Mine(attributes.ClassManager): pass
+ class A(object):
+ __sa_instrumentation_manager__ = Mine
+
+ attributes.register_class(A)
+ eq_(type(attributes.manager_of_class(A)), Mine)
+
+ @modifies_instrumentation_finders
+ def test_customfinder_greedy(self):
+ class Mine(attributes.ClassManager): pass
+ class A(object): pass
+ def find(cls):
+ return Mine
+
+ attributes.instrumentation_finders.insert(0, find)
+ attributes.register_class(A)
+ eq_(type(attributes.manager_of_class(A)), Mine)
+
+ @modifies_instrumentation_finders
+ def test_customfinder_pass(self):
+ class A(object): pass
+ def find(cls):
+ return None
+
+ attributes.instrumentation_finders.insert(0, find)
+ attributes.register_class(A)
+ eq_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+
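+# A sketch of a custom finder obeying the contract the tests above
+# exercise: finders are consulted in order and the first non-None
+# result supplies the manager factory (the marker attribute name here
+# is hypothetical).
+def find_by_marker(cls):
+    return getattr(cls, '__my_manager_factory__', None)
+# opt in with: attributes.instrumentation_finders.insert(0, find_by_marker)
+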
+if __name__ == "__main__":
+ testenv.main()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
from query import QueryTest
import datetime
+from sqlalchemy.orm import attributes
class LazyTest(FixtureTest):
keep_mappers = False
q = sess.query(User)
assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all()
- @testing.uses_deprecated('SessionContext')
- def test_bindstosession(self):
- """test that lazy loaders use the mapper's contextual session if the parent instance
- is not in a session, and that an error is raised if no contextual session"""
-
- from sqlalchemy.ext.sessioncontext import SessionContext
- ctx = SessionContext(create_session)
- m = mapper(User, users, properties = dict(
- addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True)
- ), extension=ctx.mapper_extension)
- q = ctx.current.query(m)
- u = q.filter(users.c.id == 7).first()
- ctx.current.expunge(u)
- assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u
-
- clear_mappers()
+ def test_needs_parent(self):
+        """test the error raised when the parent object is not bound to a Session."""
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), lazy=True)
})
- try:
- sess = create_session()
- q = sess.query(User)
- u = q.filter(users.c.id == 7).first()
- sess.expunge(u)
- assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u
- assert False
- except exceptions.InvalidRequestError, err:
- assert "not bound to a Session, and no contextual session" in str(err)
+ sess = create_session()
+ q = sess.query(User)
+ u = q.filter(users.c.id == 7).first()
+ sess.expunge(u)
+ self.assertRaises(sa_exc.InvalidRequestError, getattr, u, 'addresses')
def test_orderby(self):
mapper(User, users, properties = {
sess = create_session()
user = sess.query(User).get(7)
- assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True)
- assert not class_mapper(Address)._is_orphan(user.addresses[0])
+ assert getattr(User, 'addresses').hasparent(attributes.instance_state(user.addresses[0]), optimistic=True)
+ assert not class_mapper(Address)._is_orphan(attributes.instance_state(user.addresses[0]))
def test_limit(self):
u2 = users.alias('u2')
s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
print [key for key in s.c.keys()]
- l = q.filter(s.c.u2_id==User.c.id).distinct().all()
+ l = q.filter(s.c.u2_id==User.id).distinct().all()
assert fixtures.user_all_result == l
def test_one_to_many_scalar(self):
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from testlib import *
-from sqlalchemy import exceptions
class Place(object):
'''represents a place'''
mapper(Transition, transition, properties={
'places':relation(Place, secondary=place_input, backref='transitions')
})
- try:
- compile_mappers()
- assert False
- except exceptions.ArgumentError, e:
- assert str(e) in [
- "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'",
- "Error creating backref 'places' on relation 'Place.transitions (Transition)': property of that name exists on mapper 'Mapper|Transition|transition'"
- ]
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Error creating backref", compile_mappers)
def testcircular(self):
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc as sa_exc, sql
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt
from testlib import *
from testlib import fixtures
from testlib.tables import *
properties={
'addresses':relation(Address, backref='email_address')
})
- self.assertRaises(exceptions.ArgumentError, compile_mappers)
+ self.assertRaises(sa_exc.ArgumentError, compile_mappers)
def test_prop_accessor(self):
mapper(User, users)
self.assertRaises(NotImplementedError, getattr, class_mapper(User), 'properties')
+ @testing.uses_deprecated(
+ 'Call to deprecated function _instance_key',
+ 'Call to deprecated function _sa_session_id',
+ 'Call to deprecated function _entity_name')
+    def test_legacy_accessors(self):
+ u1 = User()
+ assert not hasattr(u1, '_instance_key')
+ assert not hasattr(u1, '_sa_session_id')
+ assert not hasattr(u1, '_entity_name')
+
+ mapper(User, users)
+ u1 = User()
+ assert not hasattr(u1, '_instance_key')
+ assert not hasattr(u1, '_sa_session_id')
+ assert u1._entity_name is None
+
+ sess = create_session()
+ sess.save(u1)
+ assert not hasattr(u1, '_instance_key')
+ assert u1._sa_session_id == sess.hash_key
+ assert u1._entity_name is None
+
+ sess.flush()
+ assert u1._instance_key == class_mapper(u1).identity_key_from_instance(u1)
+ assert u1._sa_session_id == sess.hash_key
+ assert u1._entity_name is None
+ sess.delete(u1)
+ sess.flush()
+
def test_badcascade(self):
mapper(Address, addresses)
- self.assertRaises(exceptions.ArgumentError, relation, Address, cascade="fake, all, delete-orphan")
+ self.assertRaises(sa_exc.ArgumentError, relation, Address, cascade="fake, all, delete-orphan")
def test_columnprefix(self):
mapper(User, users, column_prefix='_', properties={
def test_no_pks(self):
s = select([users.c.user_name]).alias('foo')
- self.assertRaises(exceptions.ArgumentError, mapper, User, s)
-
+ self.assertRaises(sa_exc.ArgumentError, mapper, User, s)
+
def test_recompile_on_othermapper(self):
- """test the global '_new_mappers' flag such that a compile
+ """test the global '_new_mappers' flag such that a compile
trigger on an already-compiled mapper still triggers a check against all mappers."""
from sqlalchemy.orm import mapperlib
-
+
mapper(User, users)
compile_mappers()
assert mapperlib._new_mappers is False
-
- m = mapper(Address, addresses, properties={'user':relation(User, backref="addresses")})
-
- assert m._Mapper__props_init is False
+
+ m = mapper(Address, addresses, properties={
+ 'user': relation(User, backref="addresses")})
+
+ assert m.compiled is False
assert mapperlib._new_mappers is True
u = User()
assert User.addresses
assert mapperlib._new_mappers is False
-
+
def test_compileonsession(self):
m = mapper(User, users)
session = create_session()
def test_badconstructor(self):
"""test that if the construction of a mapped class fails, the instnace does not get placed in the session"""
class Foo(object):
- def __init__(self, one, two):
+ def __init__(self, one, two, _sa_session=None):
pass
mapper(Foo, users)
sess = create_session()
assert len(list(sess)) == 0
self.assertRaises(TypeError, Foo, 'one')
- @testing.uses_deprecated('SessionContext', 'SessionContextExt')
- def test_constructorexceptions(self):
+ def test_constructorexc(self):
"""test that exceptions raised in the mapped class are not masked by sa decorations"""
ex = AssertionError('oops')
sess = create_session()
class Foo(object):
- def __init__(self):
+ def __init__(self, **kw):
raise ex
mapper(Foo, users)
assert e is ex
clear_mappers()
- mapper(Foo, users, extension=SessionContextExt(SessionContext()))
+ mapper(Foo, users, extension=scoped_session(create_session).extension)
def bad_expunge(foo):
raise Exception("this exception should be stated as a warning")
Foo(_sa_session=sess)
assert False
except Exception, e:
- assert isinstance(e, exceptions.SAWarning)
+ assert isinstance(e, sa_exc.SAWarning), e
clear_mappers()
mapper(User, users, properties = {
'addresses' : relation(mapper(Address, addresses))
})
- assert (User.user_id==3).compare(users.c.user_id==3)
+ self.assertEquals((User.user_id==3).__str__(), (users.c.user_id==3).__str__())
clear_mappers()
m.add_property('uc_user_name2', comparable_property(
UCComparator, User.uc_user_name2))
- sess = create_session(transactional=True)
+ sess = create_session(autocommit=False)
assert sess.query(User).get(7)
u = sess.query(User).filter_by(user_name='jack').one()
'addresses':relation(Address)
}).compile()
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e)
def test_illegal_non_primary_2(self):
try:
mapper(User, users, non_primary=True)
assert False
- except exceptions.InvalidRequestError, e:
+ except sa_exc.InvalidRequestError, e:
assert "Configure a primary mapper first" in str(e)
def test_propfilters(self):
def assert_props(cls, want):
have = set([n for n in dir(cls) if not n.startswith('_')])
want = set(want)
- want.add('c')
self.assert_(have == want, repr(have) + " " + repr(want))
assert_props(Person, ['id', 'name', 'type'])
assert_props(Hoho, ['id', 'name', 'type'])
assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type'])
- @testing.uses_deprecated('//select_by', '//join_via', '//list')
- def test_recursive_select_by_deprecated(self):
- """test that no endless loop occurs when traversing for select_by"""
- m = mapper(User, users, properties={
- 'orders':relation(mapper(Order, orders), backref='user'),
- 'addresses':relation(mapper(Address, addresses), backref='user'),
- })
- q = create_session().query(m)
- q.select_by(email_address='foo')
-
def test_mappingtojoin(self):
"""test mapping to a join"""
usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
self.assert_result(l, User, user_result[0])
- @testing.uses_deprecated('//select')
- def test_customjoin_deprecated(self):
- """test that the from_obj parameter to query.select() can be used
- to totally replace the FROM parameters of the generated query."""
-
- m = mapper(User, users, properties={
- 'orders':relation(mapper(Order, orders, properties={
- 'items':relation(mapper(Item, orderitems))
- }))
- })
-
- q = create_session().query(m)
- l = q.select((orderitems.c.item_name=='item 4'), from_obj=[users.join(orders).join(orderitems)])
- self.assert_result(l, User, user_result[0])
-
def test_orderby(self):
"""test ordering at the mapper and query level"""
mapper(User, users)
q = create_session().query(User)
self.assert_(q.count()==3)
- self.assert_(q.count(users.c.user_id.in_([8,9]))==2)
-
- @testing.unsupported('firebird')
- @testing.uses_deprecated('//count_by', '//join_by', '//join_via')
- def test_count_by_deprecated(self):
- mapper(User, users)
- q = create_session().query(User)
- self.assert_(q.count_by(user_name='fred')==1)
+ self.assert_(q.filter(users.c.user_id.in_([8,9])).count()==2)
def test_manytomany_count(self):
mapper(Item, orderitems, properties = dict(
keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True),
))
q = create_session().query(Item)
- assert q.join('keywords').distinct().count(Keyword.c.name=="red") == 2
+ assert q.join('keywords').distinct().filter(Keyword.name=="red").count() == 2
def test_override(self):
# assert that overriding a column raises an error
'user_name' : relation(mapper(Address, addresses)),
}).compile()
self.assert_(False, "should have raised ArgumentError")
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
self.assert_(True)
clear_mappers()
self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
addr = sess.query(Address).filter_by(address_id=user_address_result[0]['addresses'][1][0]['address_id']).one()
- u = sess.query(User).filter_by(adname=addr).one()
- u2 = sess.query(User).filter_by(adlist=addr).one()
+ u = sess.query(User).filter(User.adname.contains(addr)).one()
+ u2 = sess.query(User).filter(User.adlist.contains(addr)).one()
assert u is u2
})
User.not_user_name
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Can't compile synonym '_user_name': no column on table 'users' named 'not_user_name'"
clear_mappers()
self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
self.assert_sql_count(testing.db, go, 1)
- @testing.uses_deprecated('//select_by')
- def test_extension_options(self):
- sess = create_session()
- class ext1(MapperExtension):
- def populate_instance(self, mapper, selectcontext, row, instance, **flags):
- """test options at the Mapper._instance level"""
- instance.TEST = "hello world"
- return EXT_CONTINUE
- mapper(User, users, extension=ext1(), properties={
- 'addresses':relation(mapper(Address, addresses), lazy=False)
- })
- class testext(MapperExtension):
- def select_by(self, *args, **kwargs):
- """test options at the Query level"""
- return "HI"
- def populate_instance(self, mapper, selectcontext, row, instance, **flags):
- """test options at the Mapper._instance level"""
- instance.TEST_2 = "also hello world"
- return EXT_CONTINUE
- l = sess.query(User).options(extension(testext())).select_by(x=5)
- assert l == "HI"
- l = sess.query(User).options(extension(testext())).get(7)
- assert l.user_id == 7
- assert l.TEST == "hello world"
- assert l.TEST_2 == "also hello world"
- assert not hasattr(l.addresses[0], 'TEST')
- assert not hasattr(l.addresses[0], 'TEST2')
def test_eageroptions(self):
"""tests that a lazy relation can be upgraded to an eager relation via the options method"""
sess.clear()
- self.assertRaisesMessage(exceptions.ArgumentError,
- r"Can't find entity Mapper\|Order\|orders in Query. Current list: \['Mapper\|User\|users'\]",
- sess.query(User).options, eagerload('items', Order)
+ self.assertRaisesMessage(sa_exc.ArgumentError,
+ r"Can't find entity Mapper\|Order\|orders in Query. Current list: \['Mapper\|User\|users'\]",
+ sess.query(User).options, eagerload(Order.items)
)
# eagerload "keywords" on items. it will lazy load "orders", then lazy load
def setUpAll(self):
tables.create()
- global methods, Ext
+ def tearDown(self):
+ clear_mappers()
+ tables.delete()
+ def tearDownAll(self):
+ tables.drop()
+
+ def extension(self):
methods = []
class Ext(MapperExtension):
+ def instrument_class(self, mapper, cls):
+ methods.append('instrument_class')
+ return EXT_CONTINUE
+
+ def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
+ methods.append('init_instance')
+ return EXT_CONTINUE
+
+ def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
+ methods.append('init_failed')
+ return EXT_CONTINUE
+
def load(self, query, *args, **kwargs):
methods.append('load')
return EXT_CONTINUE
methods.append('after_delete')
return EXT_CONTINUE
- def tearDown(self):
- clear_mappers()
- methods[:] = []
- tables.delete()
-
- def tearDownAll(self):
- tables.drop()
+ return Ext, methods
def test_basic(self):
"""test that common user-defined methods get called."""
+ Ext, methods = self.extension()
+
mapper(User, users, extension=Ext())
sess = create_session()
u = User()
sess.flush()
sess.delete(u)
sess.flush()
- self.assertEquals(methods,
- ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 'translate_row',
- 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete']
- )
+ self.assertEquals(methods,
+ ['instrument_class', 'init_instance', 'before_insert',
+ 'after_insert', 'load', 'translate_row', 'populate_instance',
+ 'append_result', 'get', 'translate_row', 'create_instance',
+ 'populate_instance', 'append_result', 'before_update',
+ 'after_update', 'before_delete', 'after_delete'])
+
def test_inheritance(self):
- # test using inheritance
+ Ext, methods = self.extension()
+
class AdminUser(User):
pass
sess.flush()
sess.delete(am)
sess.flush()
- self.assertEquals(methods,
- ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get',
- 'translate_row', 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete'])
+ self.assertEquals(methods,
+ ['instrument_class', 'instrument_class', 'init_instance',
+ 'before_insert', 'after_insert', 'load', 'translate_row',
+ 'populate_instance', 'append_result', 'get', 'translate_row',
+ 'create_instance', 'populate_instance', 'append_result',
+ 'before_update', 'after_update', 'before_delete', 'after_delete'])
def test_after_with_no_changes(self):
# test that after_update is called even if no cols were updated
+ Ext, methods = self.extension()
+
mapper(Item, orderitems, extension=Ext() , properties={
'keywords':relation(Keyword, secondary=itemkeywords)
})
sess.save(i1)
sess.save(k1)
sess.flush()
- self.assertEquals(methods, ['before_insert', 'after_insert', 'before_insert', 'after_insert'])
+ self.assertEquals(methods,
+ ['instrument_class', 'instrument_class', 'init_instance',
+ 'init_instance', 'before_insert', 'after_insert',
+ 'before_insert', 'after_insert'])
- methods[:] = []
+ del methods[:]
i1.keywords.append(k1)
sess.flush()
self.assertEquals(methods, ['before_update', 'after_update'])
def test_inheritance_with_dupes(self):
+ Ext, methods = self.extension()
+
# test using inheritance, same extension on both mappers
class AdminUser(User):
pass
sess.flush()
sess.delete(am)
sess.flush()
- self.assertEquals(methods,
- ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 'translate_row',
- 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete']
- )
+ self.assertEquals(methods,
+ ['instrument_class', 'instrument_class', 'init_instance',
+ 'before_insert', 'after_insert', 'load', 'translate_row',
+ 'populate_instance', 'append_result', 'get', 'translate_row',
+ 'create_instance', 'populate_instance', 'append_result',
+ 'before_update', 'after_update', 'before_delete', 'after_delete'])
+
+ def test_single_instrumentor(self):
+ ext_None, methods_None = self.extension()
+ ext_x, methods_x = self.extension()
+
+ def reset():
+ clear_mappers()
+ del methods_None[:]
+ del methods_x[:]
+
+ mapper(User, users, extension=ext_None())
+ mapper(User, users, extension=ext_x(), entity_name='x')
+ User()
+
+ self.assertEquals(methods_None, ['instrument_class', 'init_instance'])
+ self.assertEquals(methods_x, [])
+
+ reset()
+
+ mapper(User, users, extension=ext_x(), entity_name='x')
+ mapper(User, users, extension=ext_None())
+ User()
+
+ self.assertEquals(methods_x, ['instrument_class', 'init_instance'])
+ self.assertEquals(methods_None, [])
+
+ reset()
+
+ ext_y, methods_y = self.extension()
+
+ mapper(User, users, extension=ext_x(), entity_name='x')
+ mapper(User, users, extension=ext_y(), entity_name='y')
+ User()
+
+ self.assertEquals(methods_x, ['instrument_class', 'init_instance'])
+ self.assertEquals(methods_y, [])
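+
+        # in both orderings above, only the extension on the mapper defined
+        # first records 'instrument_class' and 'init_instance'; the other
+        # extension records nothing.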
+
class RequirementsTest(ORMTest):
"""Tests the contract for user classes."""
class OldStyle:
pass
- self.assertRaises(exceptions.ArgumentError, mapper, OldStyle, t1)
+ self.assertRaises(sa_exc.ArgumentError, mapper, OldStyle, t1)
class NoWeakrefSupport(str):
pass
# TODO: is weakref support detectable without an instance?
- #self.assertRaises(exceptions.ArgumentError, mapper, NoWeakrefSupport, t2)
+ #self.assertRaises(sa_exc.ArgumentError, mapper, NoWeakrefSupport, t2)
def test_comparison_overrides(self):
"""Simple tests to ensure users can supply comparison __methods__.
return self.value == other.value
return False
-
mapper(H1, t1, properties={
'h2s': relation(H2, backref='h1'),
'h3s': relation(H3, secondary=t4, backref='h1s'),
def __ne__(self, other):
raise NotImplementedError()
+class MagicNamesTest(ORMTest):
+
+ def define_tables(self, metadata):
+ Table('cartographers', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('alias', String(50)),
+ Column('quip', String(100)))
+ Table('maps', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('cart_id', Integer,
+ ForeignKey('cartographers.id')),
+ Column('state', String(2)),
+ Column('data', Text))
+
+ def tables(self):
+ cat = testing._otest_metadata.tables
+ return cat['cartographers'], cat['maps']
+
+ def classes(self):
+ class Base(object):
+ def __init__(self, **kw):
+ for key, value in kw.iteritems():
+ setattr(self, key, value)
+ class Cartographer(Base): pass
+ class Map(Base): pass
+
+ return Cartographer, Map
+
+ @testing.future
+ def test_mappish(self):
+ t1, t2 = self.tables()
+ Cartographer, Map = self.classes()
+ mapper(Cartographer, t1, properties=dict(
+ query=t1.c.quip))
+ mapper(Map, t2, properties=dict(
+ mapper=relation(Cartographer, backref='maps')))
+
+ c = Cartographer(name='Lenny', alias='The Dude',
+ query='Where be dragons?')
+ m = Map(state='AK', mapper=c)
+
+ sess = create_session()
+ sess.save(c)
+ sess.flush()
+ sess.clear()
+
+ for C, M in ((Cartographer, Map), (aliased(Cartographer), aliased(Map))):
+ print C, M
+ c1 = (sess.query(C).
+ filter(C.alias=='The Dude').
+ filter(C.query=='Where be dragons?')).one()
+ m1 = sess.query(M).filter(M.mapper==c1).one()
+
+ @testing.future
+ def test_stateish(self):
+ from sqlalchemy.orm import attributes
+ if hasattr(attributes, 'ClassManager'):
+ syn1 = attributes.ClassManager.STATE_ATTR
+ syn2 = attributes.ClassManager.MANAGER_ATTR
+ else:
+ syn1 = '_state'
+ syn2 = '_class_state'
+
+
+ t1, t2 = self.tables()
+ Cartographer, Map = self.classes()
+ mapper(Map, t2, properties=dict(
+ syn1=t2.c.state,
+ syn2=t2.c.data))
+
+ m = Map()
+ setattr(m, syn1, 'AK')
+ setattr(m, syn2, '10x10')
+
+ sess = create_session()
+ sess.save(m)
+ sess.flush()
+ sess.clear()
+
+ for M in (Map, aliased(Map)):
+ print M
+ sess.query(M).filter(getattr(M, syn1) == 'AK').one()
+ sess.query(M).filter(getattr(M, syn2) == '10x10').one()
+
+
class ScalarRequirementsTest(ORMTest):
def define_tables(self, metadata):
import pickle
t1 = Table('t1', metadata, Column('id', Integer, primary_key=True),
        Column('data', PickleType(pickler=pickle)) # don't use cPickle due to import weirdness
)
-
+
def test_correct_comparison(self):
-
+
class H1(fixtures.Base):
pass
-
+
mapper(H1, t1)
-
+
h1 = H1(data=NoEqFoo('12345'))
s = create_session()
s.save(h1)
s.clear()
h1 = s.get(H1, h1.id)
assert h1.data.data == '12345'
-
+
if __name__ == "__main__":
testenv.main()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
-from sqlalchemy.orm import mapperlib
+from sqlalchemy.orm import mapperlib, attributes
from sqlalchemy.util import OrderedSet
from testlib import *
from testlib import fixtures
clear_mappers()
tables.delete()
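+    # on_load_tracker: registers a counting listener on the class's 'on_load'
+    # event via attributes.manager_of_class(), returning the canary so tests
+    # can assert how many instances were loaded during merge() operations.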
+ def on_load_tracker(self, cls, canary=None):
+ if canary is None:
+ def canary(instance):
+ canary.called += 1
+ canary.called = 0
+
+ manager = attributes.manager_of_class(cls)
+ manager.events.add_listener('on_load', canary)
+
+ return canary
+
def test_transient_to_pending(self):
class User(fixtures.Base):
pass
mapper(User, users)
sess = create_session()
+ on_load = self.on_load_tracker(User)
u = User(user_id=7, user_name='fred')
+ assert on_load.called == 0
u2 = sess.merge(u)
+ assert on_load.called == 1
assert u2 in sess
self.assertEquals(u2, User(user_id=7, user_name='fred'))
sess.flush()
sess.clear()
self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred'))
-
+
def test_transient_to_pending_collection(self):
class User(fixtures.Base):
pass
pass
mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
mapper(Address, addresses)
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
u = User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=1, email_address='fred1'),
Address(address_id=2, email_address='fred2'),
- ]))
+ ]))
+ assert on_load.called == 0
+
sess = create_session()
sess.merge(u)
+ assert on_load.called == 3
+
+ merged_users = [e for e in sess if isinstance(e, User)]
+ assert len(merged_users) == 1
+ assert merged_users[0] is not u
+
sess.flush()
sess.clear()
- self.assertEquals(sess.query(User).one(),
+ self.assertEquals(sess.query(User).one(),
User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=1, email_address='fred1'),
Address(address_id=2, email_address='fred2'),
]))
)
-
+
def test_transient_to_persistent(self):
class User(fixtures.Base):
pass
mapper(User, users)
+ on_load = self.on_load_tracker(User)
+
sess = create_session()
u = User(user_id=7, user_name='fred')
sess.save(u)
sess.flush()
sess.clear()
-
- u2 = User(user_id=7, user_name='fred jones')
+
+ assert on_load.called == 0
+
+ _u2 = u2 = User(user_id=7, user_name='fred jones')
+ assert on_load.called == 0
u2 = sess.merge(u2)
+ assert u2 is not _u2
+ assert on_load.called == 1
sess.flush()
sess.clear()
self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred jones'))
-
+ assert on_load.called == 2
+
def test_transient_to_persistent_collection(self):
class User(fixtures.Base):
pass
class Address(fixtures.Base):
pass
- mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+ mapper(User, users, properties={
+ 'addresses':relation(Address,
+ backref='user',
+ collection_class=OrderedSet, cascade="all, delete-orphan")
+ })
mapper(Address, addresses)
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
+
u = User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=1, email_address='fred1'),
Address(address_id=2, email_address='fred2'),
sess.save(u)
sess.flush()
sess.clear()
-
+
+ assert on_load.called == 0
+
u = User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=3, email_address='fred3'),
Address(address_id=4, email_address='fred4'),
]))
-
+
u = sess.merge(u)
- self.assertEquals(u,
+
+        assert on_load.called == 5, on_load.called
+        # 1. merges the User object, updating it into the session.
+        # 2., 3. merges Address ids 3 & 4, saving them into the session.
+        # 4., 5. loads the pre-existing elements of the "addresses" collection
+        # and marks Address ids 1 and 2 as deleted.
+ self.assertEquals(u,
User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=3, email_address='fred3'),
Address(address_id=4, email_address='fred4'),
)
sess.flush()
sess.clear()
- self.assertEquals(sess.query(User).one(),
+ self.assertEquals(sess.query(User).one(),
User(user_id=7, user_name='fred', addresses=OrderedSet([
Address(address_id=3, email_address='fred3'),
Address(address_id=4, email_address='fred4'),
]))
)
-
+
def test_detached_to_persistent_collection(self):
class User(fixtures.Base):
pass
pass
mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
mapper(Address, addresses)
-
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
+
a = Address(address_id=1, email_address='fred1')
u = User(user_id=7, user_name='fred', addresses=OrderedSet([
a,
sess.save(u)
sess.flush()
sess.clear()
-
+
u.user_name='fred jones'
u.addresses.add(Address(address_id=3, email_address='fred3'))
u.addresses.remove(a)
-
+
+ assert on_load.called == 0
u = sess.merge(u)
+ assert on_load.called == 4
sess.flush()
sess.clear()
-
- self.assertEquals(sess.query(User).first(),
+
+ self.assertEquals(sess.query(User).first(),
User(user_id=7, user_name='fred jones', addresses=OrderedSet([
Address(address_id=2, email_address='fred2'),
Address(address_id=3, email_address='fred3'),
]))
)
-
+
def test_unsaved_cascade(self):
"""test merge of a transient entity with two child transient entities, with a bidirectional relation."""
-
+
class User(fixtures.Base):
pass
class Address(fixtures.Base):
pass
-
+
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), cascade="all", backref="user")
})
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
sess = create_session()
+
u = User(user_id=7, user_name='fred')
a1 = Address(email_address='foo@bar.com')
a2 = Address(email_address='hoho@bar.com')
u.addresses.append(a2)
u2 = sess.merge(u)
+ assert on_load.called == 3
+
self.assertEquals(u, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
sess.flush()
sess.clear()
u2 = sess.query(User).get(7)
self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
+ assert on_load.called == 6
+
def test_attribute_cascade(self):
"""test merge of a persistent entity with two child persistent entities."""
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), backref='user')
})
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
+
sess = create_session()
# set up data and save
u.user_name = 'fred2'
u.addresses[1].email_address = 'hoho@lalala.com'
+ assert on_load.called == 3
+
# new session, merge modified data into session
sess3 = create_session()
u3 = sess3.merge(u)
+ assert on_load.called == 6
# ensure local changes are pending
self.assertEquals(u3, User(user_id=7, user_name='fred2', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@lalala.com')]))
sess.clear()
u = sess.query(User).get(7)
self.assertEquals(u, User(user_id=7, user_name='fred2', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@lalala.com')]))
+ assert on_load.called == 9
# merge persistent object into another session
sess4 = create_session()
sess4.flush()
# no changes; therefore flush should do nothing
self.assert_sql_count(testing.db, go, 0)
+ assert on_load.called == 12
# test with "dontload" merge
sess5 = create_session()
# but also, dont_load wipes out any difference in committed state,
# so no flush at all
self.assert_sql_count(testing.db, go, 0)
+ assert on_load.called == 15
sess4 = create_session()
u = sess4.merge(u, dont_load=True)
sess4.flush()
# afafds change flushes
self.assert_sql_count(testing.db, go, 1)
+ assert on_load.called == 18
sess5 = create_session()
u2 = sess5.query(User).get(u.user_id)
assert u2.user_name == 'fred2'
assert u2.addresses[1].email_address == 'afafds'
+ assert on_load.called == 21
def test_one_to_many_cascade(self):
'addresses':relation(mapper(Address, addresses)),
'orders':relation(Order, backref='customer')
})
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
+ self.on_load_tracker(Order, on_load)
sess = create_session()
u = User()
sess.save(u)
sess.flush()
+ assert on_load.called == 0
+
sess2 = create_session()
u2 = sess2.query(User).get(u.user_id)
+ assert on_load.called == 1
+
u.orders[0].items[1].item_name = 'item 2 modified'
sess2.merge(u)
assert u2.orders[0].items[1].item_name == 'item 2 modified'
+ assert on_load.called == 2
+
+ sess3 = create_session()
+ o2 = sess3.query(Order).get(o.order_id)
+ assert on_load.called == 3
- sess2 = create_session()
- o2 = sess2.query(Order).get(o.order_id)
o.customer.user_name = 'also fred'
- sess2.merge(o)
+ sess3.merge(o)
+ assert on_load.called == 4
assert o2.customer.user_name == 'also fred'
def test_one_to_one_cascade(self):
mapper(User, users, properties={
'address':relation(mapper(Address, addresses),uselist = False)
})
+ on_load = self.on_load_tracker(User)
+ self.on_load_tracker(Address, on_load)
sess = create_session()
+
u = User()
u.user_id = 7
u.user_name = "fred"
sess.save(u)
sess.flush()
+ assert on_load.called == 0
+
sess2 = create_session()
u2 = sess2.query(User).get(7)
+ assert on_load.called == 1
u2.user_name = 'fred2'
u2.address.email_address = 'hoho@lalala.com'
+ assert on_load.called == 2
u3 = sess.merge(u2)
-
+ assert on_load.called == 2
+ assert u3 is u
+
def test_transient_dontload(self):
mapper(User, users)
sess = create_session()
u = User()
- self.assertRaisesMessage(exceptions.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
def test_dontload_with_backrefs(self):
try:
sess2.merge(u, dont_load=True)
assert False
- except exceptions.InvalidRequestError, e:
+ except sa_exc.InvalidRequestError, e:
assert "merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True." in str(e)
u2 = sess2.query(User).get(7)
u2 = sess2.merge(u, dont_load=True)
assert not sess2.dirty
# assert merged instance has a mapper and lazy load proceeds
- assert hasattr(u2, '_entity_name')
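+        # entity_name is now read from the InstanceState rather than a '_entity_name' instance attribute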
+ state = attributes.instance_state(u2)
+ assert state.entity_name is not attributes.NO_ENTITY_NAME
assert mapperlib.has_mapper(u2)
def go():
assert u2.addresses != []
assert not sess2.dirty
a2 = u2.addresses[0]
a2.email_address='somenewaddress'
- assert not object_mapper(a2)._is_orphan(a2)
+ assert not object_mapper(a2)._is_orphan(attributes.instance_state(a2))
sess2.flush()
sess2.clear()
assert sess2.query(User).get(u2.user_id).addresses[0].email_address == 'somenewaddress'
# if dont_load is changed to support dirty objects, this code needs to pass
a2 = u2.addresses[0]
a2.email_address='somenewaddress'
- assert not object_mapper(a2)._is_orphan(a2)
+ assert not object_mapper(a2)._is_orphan(attributes.instance_state(a2))
sess2.flush()
sess2.clear()
assert sess2.query(User).get(u2.user_id).addresses[0].email_address == 'somenewaddress'
- except exceptions.InvalidRequestError, e:
+ except sa_exc.InvalidRequestError, e:
assert "dont_load=True option does not support" in str(e)
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exceptions
-
+from sqlalchemy.orm import attributes, exc as orm_exc
from testlib.fixtures import *
from testlib import *
sess.flush()
assert sess.get(User, 'jack') is u1
- users.update(values={u1.c.username:'jack'}).execute(username='ed')
+ users.update(values={User.username:'jack'}).execute(username='ed')
- try:
- # expire/refresh works off of primary key. the PK is gone
- # in this case so theres no way to look it up. criterion-
- # based session invalidation could solve this [ticket:911]
- sess.expire(u1)
- u1.username
- assert False
- except exceptions.InvalidRequestError, e:
- assert "Could not refresh instance" in str(e)
+ # expire/refresh works off of primary key. the PK is gone
+        # in this case so there's no way to look it up.  criterion-
+ # based session invalidation could solve this [ticket:911]
+ sess.expire(u1)
+ self.assertRaises(orm_exc.ObjectDeletedError, getattr, u1, 'username')
sess.clear()
assert sess.get(User, 'jack') is None
u1.username = 'ed'
print id(a1), id(a2), id(u1)
- print u1._state.parents
+ print attributes.instance_state(u1).parents
def go():
sess.flush()
if passive_updates:
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
from testlib import *
class Jack(object):
def setUpAll(self):
global jack, port, metadata, ctx
metadata = MetaData(testing.db)
- ctx = SessionContext(create_session)
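+    # scoped_session stands in for the SessionContext extension used previously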
+ ctx = scoped_session(create_session)
jack = Table('jack', metadata,
Column('id', Integer, primary_key=True),
#Column('room_id', Integer, ForeignKey("room.id")),
def tearDownAll(self):
metadata.drop_all()
- @testing.uses_deprecated('SessionContext')
def test1(self):
- mapper(Port, port, extension=ctx.mapper_extension)
+ mapper(Port, port, extension=ctx.extension)
mapper(Jack, jack, order_by=[jack.c.number],properties = {
'port': relation(Port, backref='jack', uselist=False, lazy=True),
- }, extension=ctx.mapper_extension)
+ }, extension=ctx.extension)
j=Jack(number='101')
p=Port(name='fa0/1')
j.port=p
- ctx.current.flush()
+ ctx.flush()
jid = j.id
pid = p.id
- j=ctx.current.query(Jack).get(jid)
- p=ctx.current.query(Port).get(pid)
+ j=ctx.query(Jack).get(jid)
+ p=ctx.query(Port).get(pid)
print p.jack
assert p.jack is not None
assert p.jack is j
p.jack=None
assert j.port is None #works
- ctx.current.clear()
+ ctx.clear()
- j=ctx.current.query(Jack).get(jid)
- p=ctx.current.query(Port).get(pid)
+ j=ctx.query(Jack).get(jid)
+ p=ctx.query(Port).get(pid)
j.port=None
self.assert_(p.jack is None)
- ctx.current.flush()
+ ctx.flush()
- ctx.current.delete(j)
- ctx.current.flush()
+ ctx.delete(j)
+ ctx.flush()
if __name__ == "__main__":
testenv.main()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
)
def test_polymorphic_deferred(self):
- mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type, polymorphic_fetch='deferred')
+ mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type)
mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser')
eu = EmailUser(name="user1", email_address='foo@bar.com')
import testenv; testenv.configure_for_tests()
import operator
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
from sqlalchemy.sql import compiler
from sqlalchemy.engine import default
from sqlalchemy.orm import *
from testlib import engines
from testlib.fixtures import *
-from sqlalchemy.orm.util import _join as join, _outerjoin as outerjoin
+from sqlalchemy.orm.util import join, outerjoin, with_parent
class QueryTest(FixtureTest):
keep_mappers = True
keep_data = True
+
def setup_mappers(self):
mapper(User, users, properties={
'addresses':relation(Address, backref='user'),
s = create_session()
- try:
- s.query(User).join('addresses').filter(Address.user_id==8).get(7)
- assert False
- except exceptions.SAWarning, e:
- assert str(e) == "Query.get() being called on a Query with existing criterion; criterion is being ignored."
+ q = s.query(User).join('addresses').filter(Address.user_id==8)
+ self.assertRaises(sa_exc.SAWarning, q.get, 7)
@testing.emits_warning('Query.*')
def warns():
try:
assert s.query(User).load(19) is None
assert False
- except exceptions.InvalidRequestError:
+ except sa_exc.InvalidRequestError:
assert True
u = s.query(User).load(7)
assert u.addresses[0].email_address == 'jack@bean.com'
assert u.orders[1].items[2].description == 'item 5'
+class InvalidGenerationsTest(QueryTest):
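+    """Test queries which raise or warn when methods are called in an invalid order."""
+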
+ def test_no_limit_offset(self):
+ s = create_session()
+
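+        # once limit() is applied, further join()/filter()/filter_by() emit a warning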
+ q = s.query(User).limit(2)
+ self.assertRaises(sa_exc.SAWarning, q.join, "addresses")
+
+ self.assertRaises(sa_exc.SAWarning, q.filter, User.name=='ed')
+
+ self.assertRaises(sa_exc.SAWarning, q.filter_by, name='ed')
+
+ def test_no_from(self):
+ s = create_session()
+
+ q = s.query(User).select_from(users)
+ self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+
+ q = s.query(User).join('addresses')
+ self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+
+ # this is fine, however
+ q.from_self()
+
class OperatorTest(QueryTest):
"""test sql.Comparator implementation for MapperProperties"""
c = expr.compile(dialect=default.DefaultDialect())
assert str(c) == compare, "%s != %s" % (str(c), compare)
+class RawSelectTest(QueryTest, AssertsCompiledSQL):
+ """compare a bunch of select() tests with the equivalent Query using straight table/columns.
+
+    Results should be the same, as Query should act as a select() pass-through for ClauseElement entities.
+
+ """
+ def test_select(self):
+ sess = create_session()
+
+ self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement,
+ "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1")
+
+ self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement,
+ "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users")
+ # a little tedious here, adding labels to work around Query's auto-labelling.
+ # also correlate needed explicitly. hmmm.....
+ # TODO: can we detect only one table in the "froms" and then turn off use_labels ?
+ s = sess.query(addresses.c.id.label('id'), addresses.c.email_address.label('email')).\
+ filter(addresses.c.user_id==users.c.id).correlate(users).statement.alias()
+
+ self.assert_compile(sess.query(users, s.c.email).select_from(users.join(s, s.c.id==users.c.id)).with_labels().statement,
+ "SELECT users.id AS users_id, users.name AS users_name, anon_1.email AS anon_1_email "
+ "FROM users JOIN (SELECT addresses.id AS id, addresses.email_address AS email FROM addresses "
+ "WHERE addresses.user_id = users.id) AS anon_1 ON anon_1.id = users.id",
+ dialect=default.DefaultDialect()
+ )
+
+ x = func.lala(users.c.id).label('foo')
+ self.assert_compile(sess.query(x).filter(x==5).statement,
+ "SELECT lala(users.id) AS foo FROM users WHERE lala(users.id) = :param_1", dialect=default.DefaultDialect())
+
class CompileTest(QueryTest):
+
def test_deferred(self):
session = create_session()
s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).compile()
try:
sess.query(User).filter(User.addresses == address)
assert False
- except exceptions.InvalidRequestError:
+ except sa_exc.InvalidRequestError:
assert True
assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all()
try:
assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
assert False
- except exceptions.InvalidRequestError:
+ except sa_exc.InvalidRequestError:
assert True
#assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
filter(User.addresses.any(id=4)).all()
assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all()
-
- @testing.fails_on_everything_except()
- def test_broken_any_1(self):
- sess = create_session()
- # overcorrelates
+ # test that any() doesn't overcorrelate
assert [User(id=7), User(id=8)] == sess.query(User).join("addresses").filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all()
-
- def test_broken_any_2(self):
- sess = create_session()
- # works, filter is before the join
- assert [User(id=7), User(id=8)] == sess.query(User).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).join("addresses", aliased=True).all()
-
- def test_broken_any_3(self):
- sess = create_session()
-
- # works, filter is after the join, but reset_joinpoint is called, removing aliasing
- assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(Address.email_address != None).reset_joinpoint().filter(~User.addresses.any(email_address='fred@fred.com')).all()
+ # test that the contents are not adapted by the aliased join
+ assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all()
- @testing.fails_on_everything_except()
- def test_broken_any_4(self):
- sess = create_session()
-
- # filter is after the join, gets aliased. in 0.5 any(), has() and not contains() are shielded from aliasing
assert [User(id=10)] == sess.query(User).outerjoin("addresses", aliased=True).filter(~User.addresses.any()).all()
-
+
@testing.unsupported('maxdb') # can core
def test_has(self):
sess = create_session()
assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+ # test has() doesn't overcorrelate
+ assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+
+        # test has() doesn't get subquery contents adapted by aliased join
+ assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+
dingaling = sess.query(Dingaling).get(2)
assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all()
(User(id=8), Address(id=4)),
(User(id=9), Address(id=5))
] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().join('addresses').add_entity(Address).order_by(User.id, Address.id).all()
+
+ def test_multiple_entities(self):
+ sess = create_session()
+
+ self.assertEquals(
+ sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(),
+ [
+ (User(id=8), Address(id=2)),
+ (User(id=9), Address(id=5))
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(),
+ (User(id=8, addresses=[Address(), Address(), Address()]), Address(id=2)),
+ )
class AggregateTest(QueryTest):
+
def test_sum(self):
sess = create_session()
orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
assert orders.sum(Order.user_id * Order.address_id) == 79
- @testing.uses_deprecated('Call to deprecated function apply_sum')
def test_apply(self):
sess = create_session()
- assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_([2, 3, 4])).one() == 79
+ assert sess.query(func.sum(Order.user_id * Order.address_id)).filter(Order.id.in_([2, 3, 4])).one() == (79,)
def test_having(self):
sess = create_session()
- assert [User(name=u'ed',id=8)] == sess.query(User).group_by([c for c in User.c]).join('addresses').having(func.count(Address.c.id)> 2).all()
+        assert [User(name=u'ed',id=8)] == sess.query(User).group_by(User).join('addresses').having(func.count(Address.id) > 2).all()
- assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).group_by([c for c in User.c]).join('addresses').having(func.count(Address.c.id)< 2).all()
+        assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).group_by(User).join('addresses').having(func.count(Address.id) < 2).all()
class CountTest(QueryTest):
def test_basic(self):
o = sess.query(Order).with_parent(u1, property='orders').all()
assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
- # test static method
- o = Query.query_from_parent(u1, property='orders', session=sess).all()
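+        # with_parent() is also available as a standalone criterion generator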
+ o = sess.query(Order).filter(with_parent(u1, User.orders)).all()
assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
-
+
+ # test static method
+ @testing.uses_deprecated(".*query_from_parent")
+ def go():
+ o = Query.query_from_parent(u1, property='orders', session=sess).all()
+ assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
+ go()
+
# test generative criterion
o = sess.query(Order).with_parent(u1).filter(orders.c.id>2).all()
assert [Order(description="order 3"), Order(description="order 5")] == o
try:
q = sess.query(Item).with_parent(u1)
assert False
- except exceptions.InvalidRequestError, e:
+ except sa_exc.InvalidRequestError, e:
assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'"
def test_m2m(self):
class JoinTest(QueryTest):
- def test_getjoinable_tables(self):
- sess = create_session()
-
- sel1 = select([users]).alias()
- sel2 = select([users], from_obj=users.join(addresses)).alias()
-
- j1 = sel1.join(users, sel1.c.id==users.c.id)
- j2 = j1.join(addresses)
-
- for from_obj, assert_cond in (
- (users, [users]),
- (users.join(addresses), [users, addresses]),
- (sel1, [sel1]),
- (sel2, [sel2]),
- (sel1.join(users, sel1.c.id==users.c.id), [sel1, users]),
- (sel2.join(users, sel2.c.id==users.c.id), [sel2, users]),
- (j2, [j1, j2, sel1, users, addresses])
-
- ):
- ret = set(sess.query(User).select_from(from_obj)._get_joinable_tables())
- self.assertEquals(ret, set(assert_cond).union([from_obj]), [x.description for x in ret])
-
def test_overlapping_paths(self):
for aliased in (True,False):
# load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
def test_orderby_arg_bug(self):
sess = create_session()
+ # no arg error
+ result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all()
+
+ def test_no_onclause(self):
+ sess = create_session()
+
+ self.assertEquals(
+ sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(),
+ [User(name='jack')]
+ )
+
+ self.assertEquals(
+ sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(),
+ [User(name='jack')]
+ )
+ def test_clause_onclause(self):
+ sess = create_session()
+
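+        # join() accepts (target, onclause) tuples, allowing several explicit hops in one call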
+ self.assertEquals(
+ sess.query(User).join(
+ (Order, User.id==Order.user_id),
+ (order_items, Order.id==order_items.c.order_id),
+ (Item, order_items.c.item_id==Item.id)
+ ).filter(Item.description == 'item 4').all(),
+ [User(name='jack')]
+ )
+
# no arg error
result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all()
l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
self.assertEquals(l, [(user8, address3)])
-
l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all()
self.assertEquals(l, [(user8, address3)])
l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
self.assertEquals(l, [(user8, address3)])
+ # this is the first test where we are joining "backwards" - from AdAlias to User even though
+ # the query is against User
+ q = sess.query(User, AdAlias)
+ l = q.join(AdAlias.user).filter(User.name=='ed')
+ self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+
+ q = sess.query(User, AdAlias).select_from(join(AdAlias, User, AdAlias.user)).filter(User.name=='ed')
+        self.assertEquals(q.all(), [(user8, address2),(user8, address3),(user8, address4),])
+
+ def test_implicit_joins_from_aliases(self):
+ sess = create_session()
+ OrderAlias = aliased(Order)
+
+ self.assertEquals(
+ sess.query(OrderAlias).join('items').filter_by(description='item 3').all(),
+ [
+ Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1),
+ Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2),
+ Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3)
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').all(),
+ [
+ (User(name=u'jack',id=7), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), u'item 3'),
+ (User(name=u'jack',id=7), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), u'item 3'),
+ (User(name=u'fred',id=9), Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2), u'item 3')
+ ]
+ )
+
def test_aliased_classes_m2m(self):
sess = create_session()
]
)
- def test_generative_join(self):
- # test that alised_ids is copied
- sess = create_session()
- q = sess.query(User).add_entity(Address)
- q1 = q.join('addresses', aliased=True)
- q2 = q.join('addresses', aliased=True)
- q3 = q2.join('addresses', aliased=True)
- q4 = q2.join('addresses', aliased=True, id='someid')
- q5 = q2.join('addresses', aliased=True, id='someid')
- q6 = q5.join('addresses', aliased=True, id='someid')
- assert q1._alias_ids[class_mapper(Address)] != q2._alias_ids[class_mapper(Address)]
- assert q2._alias_ids[class_mapper(Address)] != q3._alias_ids[class_mapper(Address)]
- assert q4._alias_ids['someid'] != q5._alias_ids['someid']
-
def test_reset_joinpoint(self):
for aliased in (True, False):
# load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
assert q.count() == 1
assert [User(id=7)] == q.all()
+
# test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1
- # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the
- # join clauses that are cached by the relationship.
- q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Order.description=="item 1")
+ q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Item.description=="item 1")
assert [] == q.all()
assert q.count() == 0
q = sess.query(User).join('orders', aliased=True).filter(Order.items.any(Item.description=='item 4'))
assert [User(id=7)] == q.all()
-
- def test_aliased_add_entity(self):
- """test the usage of aliased joins with add_entity()"""
- sess = create_session()
- q = sess.query(User).join('orders', aliased=True, id='order1').filter(Order.description=="order 3").join(['orders', 'items'], aliased=True, id='item1').filter(Item.description=="item 1")
-
- try:
- q.add_entity(Order, id='fakeid').compile()
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Query has no alias identified by 'fakeid'"
-
- try:
- q.add_entity(Order, id='fakeid').instances(None)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Query has no alias identified by 'fakeid'"
-
- q = q.add_entity(Order, id='order1').add_entity(Item, id='item1')
+
+ # test that aliasing gets reset when join() is called
+ q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=="order 5")
assert q.count() == 1
- assert [(User(id=7), Order(description='order 3'), Item(description='item 1'))] == q.all()
-
- q = sess.query(User).add_entity(Order).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=='order 4')
- try:
- q.compile()
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Ambiguous join for entity 'Mapper|Order|orders'; specify id=<someid> to query.join()/query.add_entity()"
+ assert [User(id=7)] == q.all()
class MultiplePathTest(ORMTest):
def define_tables(self, metadata):
})
mapper(T2, t2)
- try:
- create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2')
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Can't join to property 't2s_2'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`."
+ q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint()
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.",
+ q.join, 't2s_2'
+ )
create_session().query(T1).join('t2s_1', aliased=True).filter(t2.c.id==5).reset_joinpoint().join('t2s_2').all()
create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2', aliased=True).all()
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
+        # better way: use select_from()
+ def go():
+ l = sess.query(User).select_from(query).options(contains_eager('addresses')).all()
+ assert fixtures.user_address_result == l
+ self.assert_sql_count(testing.db, go, 1)
+
def test_contains_eager(self):
sess = create_session()
+ # test that contains_eager suppresses the normal outer join rendering
q = sess.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses))
- self.assert_compile(q.statement, "SELECT addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "
- "addresses.email_address AS addresses_email_address, users.id AS users_id, users.name AS users_name "\
- "FROM users LEFT OUTER JOIN addresses ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect())
-
+ self.assert_compile(q.with_labels().statement, "SELECT users.id AS users_id, users.name AS users_name, "\
+ "addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\
+ "addresses.email_address AS addresses_email_address FROM users LEFT OUTER JOIN addresses "\
+ "ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect())
+
def go():
assert fixtures.user_address_result == q.all()
self.assert_sql_count(testing.db, go, 1)
sess.clear()
-
+
adalias = addresses.alias()
q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias))
def go():
assert fixtures.user_address_result == q.all()
self.assert_sql_count(testing.db, go, 1)
sess.clear()
-
+
selectquery = users.outerjoin(addresses).select(users.c.id<10, use_labels=True, order_by=[users.c.id, addresses.c.id])
q = sess.query(User)
sess.clear()
+
+ def go():
+ l = q.options(contains_eager(User.addresses)).instances(selectquery.execute())
+ assert fixtures.user_address_result[0:3] == l
+ self.assert_sql_count(testing.db, go, 1)
+ sess.clear()
+
def go():
l = q.options(contains_eager('addresses')).from_statement(selectquery).all()
assert fixtures.user_address_result[0:3] == l
selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id])
sess = create_session()
q = sess.query(User)
-
+
+ # string alias name
def go():
- # test using a string alias name
l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
sess.clear()
+ # expression.Alias object
def go():
- # test using the Alias object itself
l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
sess.clear()
- def decorate(row):
- d = {}
- for c in addresses.c:
- d[c] = row[adalias.corresponding_column(c)]
- return d
-
+ # Aliased object
+ adalias = aliased(Address)
def go():
- # test using a custom 'decorate' function
- l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
- assert fixtures.user_address_result == l
+ l = q.options(contains_eager('addresses', alias=adalias)).outerjoin((adalias, User.addresses)).order_by(User.id, adalias.id)
+ assert fixtures.user_address_result == l.all()
self.assert_sql_count(testing.db, go, 1)
sess.clear()
+
oalias = orders.alias('o1')
ialias = items.alias('i1')
- query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id).order_by(oalias.c.id).order_by(ialias.c.id)
+ query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id, oalias.c.id, ialias.c.id)
q = create_session().query(User)
# test using string alias with more than one level deep
def go():
self.assert_sql_count(testing.db, go, 1)
sess.clear()
+ # test using Aliased with more than one level deep
+ oalias = aliased(Order)
+ ialias = aliased(Item)
+ def go():
+ l = q.options(contains_eager(User.orders, alias=oalias), contains_eager(User.orders, Order.items, alias=ialias)).\
+ outerjoin((oalias, User.orders), (ialias, Order.items)).order_by(User.id, oalias.id, ialias.id)
+ assert fixtures.user_order_result == l.all()
+ self.assert_sql_count(testing.db, go, 1)
+ sess.clear()
+
+
+class MixedEntitiesTest(QueryTest):
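+    """Test queries combining mapped entities, plain columns, and aliases in one result."""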
+
def test_values(self):
sess = create_session()
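+        # values() with no columns yields an empty iterator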
+ assert list(sess.query(User).values()) == list()
+
sel = users.select(User.id.in_([7, 8])).alias()
q = sess.query(User)
q2 = q.select_from(sel).values(User.name)
q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address))[1:3].values(User.name, Address.email_address)
self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
- q2 = q.join('addresses', aliased=True).filter(User.name.like('%e%')).values(User.name, Address.email_address)
+ adalias = aliased(Address)
+ q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address)
self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
q2 = q.values(func.count(User.name))
assert q2.next() == (4,)
- u2 = users.alias()
- q2 = q.select_from(sel).filter(u2.c.id>1).order_by([users.c.id, sel.c.id, u2.c.id]).values(users.c.name, sel.c.name, u2.c.name)
+ u2 = aliased(User)
+ q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name)
self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
- q2 = q.select_from(sel).filter(users.c.id>1).values(users.c.name, sel.c.name, User.name)
- self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'ed', u'ed', u'ed')])
+ q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name)
+ self.assertEquals(list(q2), [(u'ed', u'ed', u'ed')])
+
+        # since User.xxx is aliased against "sel", this query returns nothing
+ q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name)
+ self.assertEquals(list(q2), [])
+
+        # whereas this uses users.c.xxx, which is not aliased and creates a new join
+ q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name)
+ self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')])
+ def test_tuple_labeling(self):
+ sess = create_session()
+ for row in sess.query(User, Address).join(User.addresses).all():
+ self.assertEquals(set(row.keys()), set(['User', 'Address']))
+ self.assertEquals(row.User, row[0])
+ self.assertEquals(row.Address, row[1])
+
+ for row in sess.query(User.name, User.id.label('foobar')):
+ self.assertEquals(set(row.keys()), set(['name', 'foobar']))
+ self.assertEquals(row.name, row[0])
+ self.assertEquals(row.foobar, row[1])
+
+ for row in sess.query(User).values(User.name, User.id.label('foobar')):
+ self.assertEquals(set(row.keys()), set(['name', 'foobar']))
+ self.assertEquals(row.name, row[0])
+ self.assertEquals(row.foobar, row[1])
+
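+        # anonymous aliases receive no tuple key; only named entities appear in keys()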
+ oalias = aliased(Order)
+ for row in sess.query(User, oalias).join(User.orders).all():
+ self.assertEquals(set(row.keys()), set(['User']))
+ self.assertEquals(row.User, row[0])
+
+ oalias = aliased(Order, name='orders')
+ for row in sess.query(User, oalias).join(User.orders).all():
+ self.assertEquals(set(row.keys()), set(['User', 'orders']))
+ self.assertEquals(row.User, row[0])
+ self.assertEquals(row.orders, row[1])
+
+
+ def test_column_queries(self):
+ sess = create_session()
+
+ self.assertEquals(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)])
+
+ sel = users.select(User.id.in_([7, 8])).alias()
+ q = sess.query(User.name)
+ q2 = q.select_from(sel).all()
+ self.assertEquals(list(q2), [(u'jack',), (u'ed',)])
+
+ self.assertEquals(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [
+ (u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'),
+ (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'),
+ (u'fred', u'fred@fred.com')
+ ])
+
+ self.assertEquals(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(),
+ [(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)]
+ )
+
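+        # group_by(User) expands to the entity's mapped columns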
+ self.assertEquals(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(),
+ [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
+ )
+
+ self.assertEquals(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(),
+ [(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))]
+ )
+
+ adalias = aliased(Address)
+ self.assertEquals(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(),
+ [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
+ )
+
+ self.assertEquals(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(),
+ [(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))]
+ )
+
+ # select from aliasing + explicit aliasing
+ self.assertEquals(
+ sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(),
+ [
+ (User(name=u'jack',id=7), u'jack@bean.com'),
+ (User(name=u'ed',id=8), u'ed@wood.com'),
+ (User(name=u'ed',id=8), u'ed@bettyboop.com'),
+ (User(name=u'ed',id=8), u'ed@lala.com'),
+ (User(name=u'fred',id=9), u'fred@fred.com'),
+ (User(name=u'chuck',id=10), None)
+ ]
+ )
+
+ # anon + select from aliasing
+ self.assertEquals(
+ sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(),
+ [
+ User(name=u'ed',id=8),
+ User(name=u'fred',id=9),
+ ]
+ )
+
+ # test eager aliasing, with/without select_from aliasing
+ for q in [
+ sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
+ sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
+ ]:
+ self.assertEquals(
+ q.all(),
+ [(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'),
+ (User(addresses=[
+ Address(user_id=8,email_address=u'ed@wood.com',id=2),
+ Address(user_id=8,email_address=u'ed@bettyboop.com',id=3),
+ Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@wood.com'),
+ (User(addresses=[
+ Address(user_id=8,email_address=u'ed@wood.com',id=2),
+ Address(user_id=8,email_address=u'ed@bettyboop.com',id=3),
+ Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@bettyboop.com'),
+ (User(addresses=[
+ Address(user_id=8,email_address=u'ed@wood.com',id=2),
+ Address(user_id=8,email_address=u'ed@bettyboop.com',id=3),
+ Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@lala.com'),
+ (User(addresses=[Address(user_id=9,email_address=u'fred@fred.com',id=5)],name=u'fred',id=9), u'fred@fred.com'),
+
+ (User(addresses=[],name=u'chuck',id=10), None)]
+ )
+
+ def test_self_referential(self):
+
+ sess = create_session()
+ oalias = aliased(Order)
+
+ for q in [
+ sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id),
+ sess.query(Order, oalias)._from_self().filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id),
+ # here we go....two layers of aliasing
+ sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)),
+
+ # gratuitous four layers
+ sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self()._from_self()._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)),
+
+ ]:
+
+ self.assertEquals(
+ q.all(),
+ [
+ (Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)),
+ (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)),
+ (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3))
+ ]
+ )
+
def test_multi_mappers(self):
test_session = create_session()
(user7, user8, user9, user10) = test_session.query(User).all()
(address1, address2, address3, address4, address5) = test_session.query(Address).all()
- # note the result is a cartesian product
expected = [(user7, address1),
(user8, address2),
(user8, address3),
sess = create_session()
selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
- q = sess.query(User)
- l = q.instances(selectquery.execute(), Address)
- assert l == expected
-
+ self.assertEquals(sess.query(User, Address).instances(selectquery.execute()), expected)
sess.clear()
- for aliased in (False, True):
- q = sess.query(User)
-
- q = q.add_entity(Address).outerjoin('addresses', aliased=aliased)
- l = q.all()
- assert l == expected
+ for address_entity in (Address, aliased(Address)):
+ q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id)
+ self.assertEquals(q.all(), expected)
sess.clear()
- q = sess.query(User).add_entity(Address)
- l = q.join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com').all()
- assert l == [(user8, address3)]
+ q = sess.query(User).add_entity(address_entity)
+ q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
+ self.assertEquals(q.all(), [(user8, address3)])
sess.clear()
- q = sess.query(User, Address).join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com')
- assert q.all() == [(user8, address3)]
+ q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
+ self.assertEquals(q.all(), [(user8, address3)])
sess.clear()
- q = sess.query(User, Address).join('addresses', aliased=aliased).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
+ q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)])
sess.clear()
expected = [(u, u.name) for u in sess.query(User).all()]
- for add_col in (User.name, users.c.name, User.c.name):
+ for add_col in (User.name, users.c.name):
assert sess.query(User).add_column(add_col).all() == expected
sess.clear()
- self.assertRaises(exceptions.InvalidRequestError, sess.query(User).add_column, object())
+ self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object())
- def test_ambiguous_column(self):
- sess = create_session()
-
- q = sess.query(User).join('addresses', aliased=True).join('addresses', aliased=True).add_column(Address.id)
- self.assertRaises(exceptions.InvalidRequestError, iter, q)
-
def test_multi_columns_2(self):
"""test aliased/nonalised joins with the usage of add_column()"""
sess = create_session()
(user10, 0)
]
- for aliased in (False, True):
- q = sess.query(User)
- q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(Address.id).label('count'))
- l = q.all()
- assert l == expected
- sess.clear()
+ q = sess.query(User)
+ q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count'))
+ self.assertEquals(q.all(), expected)
+ sess.clear()
+
+ adalias = aliased(Address)
+ q = sess.query(User)
+ q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count'))
+ self.assertEquals(q.all(), expected)
+ sess.clear()
s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
q = sess.query(User)
assert l == expected
- def test_two_columns(self):
+ def test_raw_columns(self):
sess = create_session()
(user7, user8, user9, user10) = sess.query(User).all()
expected = [
(user9, 1, "Name:fred"),
(user10, 0, "Name:chuck")]
- q = create_session().query(User).add_column(func.count(addresses.c.id))\
- .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=True)\
+ adalias = addresses.alias()
+ q = create_session().query(User).add_column(func.count(adalias.c.id))\
+ .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\
.group_by([c for c in users.c]).order_by(users.c.id)
assert q.all() == expected
assert q.all() == expected
sess.clear()
- # test with outerjoin() both aliased and non
- for aliased in (False, True):
- q = create_session().query(User).add_column(func.count(addresses.c.id))\
- .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=aliased)\
- .group_by([c for c in users.c]).order_by(users.c.id)
+ q = create_session().query(User).add_column(func.count(addresses.c.id))\
+ .add_column(("Name:" + users.c.name)).outerjoin('addresses')\
+ .group_by([c for c in users.c]).order_by(users.c.id)
- assert q.all() == expected
- sess.clear()
+ assert q.all() == expected
+ sess.clear()
+
+ q = create_session().query(User).add_column(func.count(adalias.c.id))\
+ .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\
+ .group_by([c for c in users.c]).order_by(users.c.id)
+
+ assert q.all() == expected
+ sess.clear()
class SelectFromTest(QueryTest):
self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)])
- self.assertEquals(sess.query(User).select_from(sel).filter(User.c.id==8).all(), [User(id=8)])
+ self.assertEquals(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)])
self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [
User(name='jack',id=7), User(name='ed',id=8)
]
)
- self.assertEquals(sess.query(User).select_from(sel).join('addresses', aliased=True).add_entity(Address).order_by(User.id).order_by(Address.id).all(),
+ adalias = aliased(Address)
+ self.assertEquals(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(),
[
(User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)),
(User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)),
sel = users.select(users.c.id.in_([7, 8]))
sess = create_session()
+
+ # TODO: remove
+ sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all()
- self.assertEquals(sess.query(User).select_from(sel).join(['orders', 'items', 'keywords']).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+ self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
User(name=u'jack',id=7)
])
- self.assertEquals(sess.query(User).select_from(sel).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+ self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
User(name=u'jack',id=7)
])
sess.clear()
def go():
- self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.c.id==8).all(),
+ self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).all(),
[User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])]
)
self.assert_sql_count(testing.db, go, 1)
def go():
self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]))
self.assert_sql_count(testing.db, go, 1)
-
+
class CustomJoinTest(QueryTest):
keep_mappers = False
node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
assert node.data=='n12'
+ ret = sess.query(Node.data).join(Node.children, aliased=True).filter_by(data='n122').all()
+ assert ret == [('n12',)]
+
+
node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first()
assert node.data=='n1'
list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)),
[('n122', 'n12', 'n1')])
+
+ def test_join_to_nonaliased(self):
+ sess = create_session()
- def test_any(self):
+ n1 = aliased(Node)
+
+ # using 'n1.parent' implicitly joins to unaliased Node
+ self.assertEquals(
+ sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(),
+ [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
+ )
+
+ # explicit (new syntax)
+ self.assertEquals(
+ sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(),
+ [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
+ )
+
+ def test_multiple_explicit_entities(self):
sess = create_session()
+ parent = aliased(Node)
+ grandparent = aliased(Node)
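+        # each alias participates as a distinct entity in both the join chain and the result tuples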
+ self.assertEquals(
+ sess.query(Node, parent, grandparent).\
+ join((Node.parent, parent), (parent.parent, grandparent)).\
+ filter(Node.data=='n122').filter(parent.data=='n12').\
+ filter(grandparent.data=='n1').first(),
+ (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+ )
+
+ self.assertEquals(
+ sess.query(Node, parent, grandparent).\
+ join((Node.parent, parent), (parent.parent, grandparent)).\
+ filter(Node.data=='n122').filter(parent.data=='n12').\
+ filter(grandparent.data=='n1')._from_self().first(),
+ (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+ )
+
+ self.assertEquals(
+ sess.query(Node, parent, grandparent).\
+ join((Node.parent, parent), (parent.parent, grandparent)).\
+ filter(Node.data=='n122').filter(parent.data=='n12').\
+ filter(grandparent.data=='n1').\
+ options(eagerload(Node.children)).first(),
+ (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+ )
+
+ self.assertEquals(
+ sess.query(Node, parent, grandparent).\
+ join((Node.parent, parent), (parent.parent, grandparent)).\
+ filter(Node.data=='n122').filter(parent.data=='n12').\
+ filter(grandparent.data=='n1')._from_self().\
+ options(eagerload(Node.children)).first(),
+ (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+ )
+
+
+ def test_any(self):
+ sess = create_session()
self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
)
class ExternalColumnsTest(QueryTest):
+ """test mappers with SQL-expressions added as column properties."""
+
keep_mappers = False
def setup_mappers(self):
def test_external_columns_bad(self):
- self.assertRaisesMessage(exceptions.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
+ self.assertRaisesMessage(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
'concat': (users.c.id * 2),
})
clear_mappers()
- self.assertRaisesMessage(exceptions.ArgumentError, "must be given a ColumnElement as its argument.", column_property,
- select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users)
- )
-
def test_external_columns_good(self):
"""test querying mappings that reference external columns or selectables."""
})
mapper(Address, addresses, properties={
- 'user':relation(User, lazy=True)
+ 'user':relation(User)
})
sess = create_session()
-
- l = sess.query(User).all()
- assert [
- User(id=7, concat=14, count=1),
- User(id=8, concat=16, count=3),
- User(id=9, concat=18, count=1),
- User(id=10, concat=20, count=0),
- ] == l
+ sess.query(Address).options(eagerload('user')).all()
+
+ self.assertEquals(sess.query(User).all(),
+ [
+ User(id=7, concat=14, count=1),
+ User(id=8, concat=16, count=3),
+ User(id=9, concat=18, count=1),
+ User(id=10, concat=20, count=0),
+ ]
+ )
address_result = [
Address(id=1, user=User(id=7, concat=14, count=1)),
self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result)
self.assert_sql_count(testing.db, go, 1)
- tuple_address_result = [(address, address.user) for address in address_result]
-
- q =sess.query(Address).join('user', aliased=True, id='ualias').join('user', aliased=True).add_column(User.concat)
- self.assertRaisesMessage(exceptions.InvalidRequestError, "Ambiguous", q.all)
-
- self.assertEquals(sess.query(Address).join('user', aliased=True, id='ualias').add_entity(User, id='ualias').all(), tuple_address_result)
+ ualias = aliased(User)
+ self.assertEquals(
+ sess.query(Address, ualias).join(('user', ualias)).all(),
+ [(address, address.user) for address in address_result]
+ )
- self.assertEquals(sess.query(Address).join('user', aliased=True, id='ualias').join('user', aliased=True).\
- add_column(User.concat, id='ualias').add_column(User.count, id='ualias').all(),
+ self.assertEquals(
+ sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).all(),
+ [
+ (Address(id=1), 1),
+ (Address(id=2), 3),
+ (Address(id=3), 3),
+ (Address(id=4), 3),
+ (Address(id=5), 1)
+ ]
+ )
+
+ self.assertEquals(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).all(),
[
(Address(id=1), 14, 1),
(Address(id=2), 16, 3),
]
)
- self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)),
- [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
+ ua = aliased(User)
+ self.assertEquals(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
+ [
+ (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1),
+ (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3),
+ (Address(id=3, user=User(id=8, concat=16, count=3)), 16, 3),
+ (Address(id=4, user=User(id=8, concat=16, count=3)), 16, 3),
+ (Address(id=5, user=User(id=9, concat=18, count=1)), 18, 1)
+ ]
)
- self.assertEquals(list(sess.query(Address).join('user', aliased=True).values(Address.id, User.id, User.concat, User.count)),
+ self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)),
[(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
)
- ua = aliased(User)
self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)),
[(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
)
import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions, types
+from sqlalchemy import exc as sa_exc, types
from sqlalchemy.orm import *
-from sqlalchemy.orm import collections
+from sqlalchemy.orm import collections, attributes, exc as orm_exc
from sqlalchemy.orm.collections import collection
from testlib import *
from testlib import fixtures
self.pagename = pagename
self.currentversion = PageVersion(self, 1)
def __repr__(self):
- return "Page jobno:%s pagename:%s %s" % (self.jobno, self.pagename, getattr(self, '_instance_key', None))
+ try:
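+ # fall back to None when the instance has no InstanceState or key yet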
+ state = attributes.instance_state(self)
+ key = state.key
+ except (KeyError, AttributeError):
+ key = None
+ return ("Page jobno:%s pagename:%s %s" %
+ (self.jobno, self.pagename, key))
def add_version(self):
self.currentversion = PageVersion(self, self.currentversion.version+1)
comment = self.add_comment()
try:
sess.flush()
assert False
- except exceptions.AssertionError, e:
+ except AssertionError, e:
assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ")
def test_no_delete_PK_BtoA(self):
try:
sess.flush()
assert False
- except exceptions.AssertionError, e:
+ except AssertionError, e:
assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ")
@testing.fails_on_everything_except('sqlite', 'mysql')
try:
sess.save(a1)
assert False
- except exceptions.AssertionError, err:
+ except AssertionError, err:
assert str(err) == "Attribute 'bs' on class '%s' doesn't handle objects of type '%s'" % (A, C)
def test_o2m_onflush(self):
class A(object):pass
sess.save(a1)
sess.save(b1)
sess.save(c1)
- try:
- sess.flush()
- assert False
- except exceptions.FlushError, err:
- assert str(err).startswith("Attempting to flush an item of type %s on collection 'A.bs (B)', which is handled by mapper 'Mapper|B|b' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ?" % C)
+ self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
+
def test_o2m_nopoly_onflush(self):
class A(object):pass
class B(object):pass
sess.save(a1)
sess.save(b1)
sess.save(c1)
- try:
- sess.flush()
- assert False
- except exceptions.FlushError, err:
- assert str(err).startswith("Attempting to flush an item of type %s on collection 'A.bs (B)', which is handled by mapper 'Mapper|B|b' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ?" % C)
+ self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
def test_m2o_nopoly_onflush(self):
class A(object):pass
sess = create_session()
sess.save(b1)
sess.save(d1)
- try:
- sess.flush()
- assert False
- except exceptions.FlushError, err:
- assert str(err).startswith("Attempting to flush an item of type %s on collection 'D.a (A)', which is handled by mapper 'Mapper|A|a' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ?" % B)
+ self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
+
def test_m2o_oncascade(self):
class A(object):pass
class B(object):pass
d1 = D()
d1.a = b1
sess = create_session()
- try:
- sess.save(d1)
- assert False
- except exceptions.AssertionError, err:
- assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B)
+ self.assertRaisesMessage(AssertionError, "doesn't handle objects of type", sess.save, d1)
class TypedAssociationTable(ORMTest):
def define_tables(self, metadata):
a = sess.query(T1).first()
self.assertEquals(a.t3s, [T3(data='t3')])
+
def test_remote_side_escalation(self):
class T1(fixtures.Base):
't3s':relation(T3, secondary=t2tot3)
})
mapper(T3, t3)
- self.assertRaisesMessage(exceptions.ArgumentError, "Specify remote_side argument", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Specify remote_side argument", compile_mappers)
class ExplicitLocalRemoteTest(ORMTest):
def define_tables(self, metadata):
)
})
mapper(T2, t2)
- self.assertRaises(exceptions.ArgumentError, compile_mappers)
+ self.assertRaises(sa_exc.ArgumentError, compile_mappers)
clear_mappers()
mapper(T1, t1, properties={
)
})
mapper(T2, t2)
- self.assertRaises(exceptions.ArgumentError, compile_mappers)
+ self.assertRaises(sa_exc.ArgumentError, compile_mappers)
class InvalidRelationEscalationTest(ORMTest):
def define_tables(self, metadata):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
def test_no_join_self_ref(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
def test_no_equated(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_no_equated_fks(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
def test_no_equated_self_ref(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_no_equated_self_ref(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
def test_no_equated_viewonly(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_no_equated_self_ref_viewonly(self):
mapper(Foo, foos, properties={
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
def test_no_equated_self_ref_viewonly_fks(self):
mapper(Foo, foos, properties={
'bars':relation(Bar, primaryjoin=foos.c.id==bars.c.fid)
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_equated_self_ref(self):
mapper(Foo, foos, properties={
'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid)
})
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_equated_self_ref_wrong_fks(self):
mapper(Foo, foos, properties={
'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id])
})
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
class InvalidRelationEscalationTestM2M(ORMTest):
def define_tables(self, metadata):
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
def test_no_secondaryjoin(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
def test_bad_primaryjoin(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
def test_bad_secondaryjoin(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
def test_no_equated_secondaryjoin(self):
mapper(Foo, foos, properties={
})
mapper(Bar, bars)
- self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for secondaryjoin condition", compile_mappers)
+ self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for secondaryjoin condition", compile_mappers)
if __name__ == "__main__":
--- /dev/null
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import *
+from testlib import *
+from testlib import fixtures
+
+
+class ScopedSessionTest(ORMTest):
+
+ def define_tables(self, metadata):
+ global table, table2
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)))
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id'))
+ )
+
+ def test_basic(self):
+ Session = scoped_session(sessionmaker())
+
+ class SomeObject(fixtures.Base):
+ query = Session.query_property()
+ class SomeOtherObject(fixtures.Base):
+ query = Session.query_property()
+
+ mapper(SomeObject, table, properties={
+ 'options':relation(SomeOtherObject)
+ })
+ mapper(SomeOtherObject, table2)
+
+ s = SomeObject(id=1, data="hello")
+ sso = SomeOtherObject()
+ s.options.append(sso)
+ Session.save(s)
+ Session.commit()
+ Session.refresh(sso)
+ Session.remove()
+
+ self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), Session.query(SomeObject).one())
+ self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), SomeObject.query.one())
+ self.assertEquals(SomeOtherObject(someid=1), SomeOtherObject.query.filter(SomeOtherObject.someid==sso.someid).one())
+
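+ # sketch of what the last three assertions rely on: query_property()
+ # gives each mapped class a 'query' attribute bound to the scoped
+ # session, so these two lookups are equivalent:
+ #
+ #     Session.query(SomeObject).one()
+ #     SomeObject.query.one()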
+
+class ScopedMapperTest(TestBase):
+ def setUpAll(self):
+ global metadata, table, table2
+ metadata = MetaData(testing.db)
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)))
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id'))
+ )
+ metadata.create_all()
+
+ def setUp(self):
+ global SomeObject, SomeOtherObject
+ class SomeObject(fixtures.Base):pass
+ class SomeOtherObject(fixtures.Base):pass
+
+ global Session
+
+ Session = scoped_session(create_session)
+ Session.mapper(SomeObject, table, properties={
+ 'options':relation(SomeOtherObject)
+ })
+ Session.mapper(SomeOtherObject, table2)
+
+ s = SomeObject()
+ s.id = 1
+ s.data = 'hello'
+ sso = SomeOtherObject()
+ s.options.append(sso)
+ Session.flush()
+ Session.clear()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+
+ def tearDown(self):
+ for table in metadata.table_iterator(reverse=True):
+ table.delete().execute()
+ clear_mappers()
+
+ def test_query(self):
+ sso = SomeOtherObject.query().first()
+ assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+
+ def test_query_compiles(self):
+ class Foo(object):
+ pass
+ Session.mapper(Foo, table2)
+ assert hasattr(Foo, 'query')
+
+ ext = MapperExtension()
+
+ class Bar(object):
+ pass
+ Session.mapper(Bar, table2, extension=[ext])
+ assert hasattr(Bar, 'query')
+
+ class Baz(object):
+ pass
+ Session.mapper(Baz, table2, extension=ext)
+ assert hasattr(Baz, 'query')
+
+ def test_validating_constructor(self):
+ s2 = SomeObject(someid=12)
+ s3 = SomeOtherObject(someid=123, bogus=345)
+
+ class ValidatedOtherObject(object): pass
+ Session.mapper(ValidatedOtherObject, table2, validate=True)
+
+ v1 = ValidatedOtherObject(someid=12)
+ self.assertRaises(sa_exc.ArgumentError, ValidatedOtherObject, someid=12, bogus=345)
+
+ def test_dont_clobber_methods(self):
+ class MyClass(object):
+ def expunge(self):
+ return "an expunge !"
+
+ Session.mapper(MyClass, table2)
+
+ assert MyClass().expunge() == "an expunge !"
+
+class ScopedMapperTest2(ORMTest):
+ def define_tables(self, metadata):
+ global table, table2
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)),
+ Column('type', String(30))
+
+ )
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id')),
+ Column('somedata', String(30)),
+ )
+
+ def test_inheritance(self):
+ def expunge_list(l):
+ for x in l:
+ Session.expunge(x)
+ return l
+
+ class BaseClass(fixtures.Base):
+ pass
+ class SubClass(BaseClass):
+ pass
+
+ Session = scoped_session(sessionmaker())
+ Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type)
+ Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass)
+
+ b = BaseClass(data='b1')
+ s = SubClass(data='s1', somedata='somedata')
+ Session.commit()
+ Session.clear()
+
+ assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all()
+ assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all()
+
+
+
+if __name__ == "__main__":
+ testenv.main()
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import *
class Subset(object):
pass
selectable = select(["x", "y", "z"])
- self.assertRaisesMessage(exceptions.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable)
@testing.emits_warning('.*creating an Alias.*')
def test_basic(self):
import testenv; testenv.configure_for_tests()
+import gc
+import pickle
from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes
from sqlalchemy.orm.session import SessionExtension
from sqlalchemy.orm.session import Session as SessionCls
from testlib import *
from testlib.tables import *
from testlib import fixtures, tables
-import pickle
-import gc
class SessionTest(TestBase, AssertsExecutionResults):
pass
def test_close(self):
- """test that flush() doenst close a connection the session didnt open"""
+ """test that flush() doesn't close a connection the session didn't open"""
+
c = testing.db.connect()
class User(object):pass
mapper(User, users)
# then see if expunge fails
session.expunge(u)
+ assert object_session(u) is attributes.instance_state(u).session_id is None
+ for a in u.addresses:
+ assert object_session(a) is attributes.instance_state(a).session_id is None
+
@engines.close_open_connections
def test_binds_from_expression(self):
"""test that Session can extract Table objects from ClauseElements and match them to tables."""
+
Session = sessionmaker(binds={users:testing.db, addresses:testing.db})
sess = Session()
sess.execute(users.insert(), params=dict(user_id=1, user_name='ed'))
conn1 = testing.db.connect()
conn2 = testing.db.connect()
- sess = create_session(transactional=True, bind=conn1)
+ sess = create_session(autocommit=False, bind=conn1)
u = User()
sess.save(u)
sess.flush()
assert testing.db.connect().execute("select count(1) from users").scalar() == 1
sess.close()
- def test_flush_noop(self):
- session = create_session()
- session.uow = object()
-
- self.assertRaises(AttributeError, session.flush)
-
- session = create_session()
- session.uow = object()
-
- session.flush(objects=[])
- session.flush(objects=set())
- session.flush(objects=())
- session.flush(objects=iter([]))
-
@testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang
@engines.close_open_connections
def test_autoflush(self):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
- sess = create_session(bind=conn1, transactional=True, autoflush=True)
+ sess = create_session(bind=conn1, autocommit=False, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
})
mapper(Address, addresses)
- sess = create_session(autoflush=True, transactional=True)
+ sess = create_session(autoflush=True, autocommit=False)
u = User(user_name='ed', addresses=[Address(email_address='foo')])
sess.save(u)
self.assertEquals(sess.query(Address).filter(Address.user==u).one(), Address(email_address='foo'))
mapper(User, users)
try:
- sess = create_session(transactional=True, autoflush=True)
+ sess = create_session(autocommit=False, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
conn1 = testing.db.connect()
conn2 = testing.db.connect()
- sess = create_session(bind=conn1, transactional=True, autoflush=True)
+ sess = create_session(bind=conn1, autocommit=False, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
assert testing.db.connect().execute("select count(1) from users").scalar() == 1
sess.commit()
- # TODO: not doing rollback of attributes right now.
- def dont_test_autoflush_rollback(self):
+ def test_autoflush_rollback(self):
tables.data()
mapper(Address, addresses)
mapper(User, users, properties={
'addresses':relation(Address)
})
- sess = create_session(transactional=True, autoflush=True)
+ sess = create_session(autocommit=False, autoflush=True)
u = sess.query(User).get(8)
newad = Address()
- newad.email_address == 'something new'
+ newad.email_address = 'something new'
u.addresses.append(newad)
u.user_name = 'some new name'
assert u.user_name == 'some new name'
assert u.user_name == 'ed'
assert len(u.addresses) == 3
assert newad not in u.addresses
-
+
+ # pending objects don't get expired
+ assert newad.email_address == 'something new'
+
+ def test_textual_execute(self):
+ """test that Session.execute() converts to text()"""
+
+ tables.data()
+ sess = create_session(bind=testing.db)
+ # use :bindparam style
+ self.assertEquals(sess.execute("select * from users where user_id=:id", {'id':7}).fetchall(), [(7, u'jack')])
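+ # sketch of the explicit equivalent (same users fixture): a plain
+ # string is coerced as if wrapped in text(), with named bind params:
+ #
+ #     from sqlalchemy import text
+ #     sess.execute(text("select * from users where user_id=:id"),
+ #                  {'id': 7})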
@engines.close_open_connections
- def test_external_joined_transaction(self):
+ def test_subtransaction_on_external(self):
class User(object):pass
mapper(User, users)
conn = testing.db.connect()
trans = conn.begin()
- sess = create_session(bind=conn, transactional=True, autoflush=True)
- sess.begin()
+ sess = create_session(bind=conn, autocommit=False, autoflush=True)
+ sess.begin(subtransactions=True)
u = User()
sess.save(u)
sess.flush()
try:
conn = testing.db.connect()
trans = conn.begin()
- sess = create_session(bind=conn, transactional=True, autoflush=True)
+ sess = create_session(bind=conn, autocommit=False, autoflush=True)
u1 = User()
sess.save(u1)
sess.flush()
conn.close()
raise
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @engines.close_open_connections
+ @testing.requires.savepoints
def test_heavy_nesting(self):
session = create_session(bind=testing.db)
session.begin()
session.connection().execute("insert into users (user_name) values ('user1')")
- session.begin()
+ session.begin(subtransactions=True)
session.begin_nested()
assert session.connection().execute("select count(1) from users").scalar() == 2
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.two_phase_transactions
def test_twophase(self):
# TODO: mock up a failure condition here
# to ensure a rollback succeeds
mapper(Address, addresses)
engine2 = create_engine(testing.db.url)
- sess = create_session(transactional=False, autoflush=False, twophase=True)
+ sess = create_session(autocommit=True, autoflush=False, twophase=True)
sess.bind_mapper(User, testing.db)
sess.bind_mapper(Address, engine2)
sess.begin()
assert users.count().scalar() == 1
assert addresses.count().scalar() == 1
- def test_joined_transaction(self):
+ def test_subtransaction_on_noautocommit(self):
class User(object):pass
mapper(User, users)
- sess = create_session(transactional=True, autoflush=True)
- sess.begin()
+ sess = create_session(autocommit=False, autoflush=True)
+ sess.begin(subtransactions=True)
u = User()
sess.save(u)
sess.flush()
assert len(sess.query(User).all()) == 0
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_nested_transaction(self):
class User(object):pass
mapper(User, users)
assert len(sess.query(User).all()) == 1
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_nested_autotrans(self):
class User(object):pass
mapper(User, users)
- sess = create_session(transactional=True)
+ sess = create_session(autocommit=False)
u = User()
sess.save(u)
sess.flush()
assert len(sess.query(User).all()) == 1
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_nested_transaction_connection_add(self):
class User(object): pass
mapper(User, users)
- sess = create_session(transactional=False)
+ sess = create_session(autocommit=True)
sess.begin()
sess.begin_nested()
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_mixed_transaction_control(self):
class User(object): pass
mapper(User, users)
- sess = create_session(transactional=False)
+ sess = create_session(autocommit=True)
sess.begin()
sess.begin_nested()
- transaction = sess.begin()
+ transaction = sess.begin(subtransactions=True)
sess.save(User())
sess.close()
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
+ @testing.requires.savepoints
def test_mixed_transaction_close(self):
class User(object): pass
mapper(User, users)
- sess = create_session(transactional=True)
+ sess = create_session(autocommit=False)
sess.begin_nested()
self.assertEquals(len(sess.query(User).all()), 1)
- @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
- 'oracle', 'maxdb')
- @testing.exclude('mysql', '<', (5, 0, 3))
def test_error_on_using_inactive_session(self):
class User(object): pass
mapper(User, users)
- sess = create_session(transactional=False)
+ sess = create_session(autocommit=True)
- try:
- sess.begin()
- sess.begin()
+ sess.begin()
+ sess.begin(subtransactions=True)
- sess.save(User())
- sess.flush()
+ sess.save(User())
+ sess.flush()
- sess.rollback()
- sess.begin()
- assert False
- except exceptions.InvalidRequestError, e:
- self.assertEquals(str(e), "The transaction is inactive due to a rollback in a subtransaction and should be closed")
+ sess.rollback()
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True)
sess.close()
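+ # sketch of the rule tested above: rolling back a subtransaction
+ # deactivates the enclosing transaction, so further begins fail until
+ # the outermost transaction is closed or rolled back:
+ #
+ #     sess.begin(); sess.begin(subtransactions=True)
+ #     sess.save(User()); sess.flush()
+ #     sess.rollback()                     # inner rollback
+ #     sess.begin(subtransactions=True)    # raises InvalidRequestError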
@engines.close_open_connections
mapper(User, users)
c = testing.db.connect()
sess = create_session(bind=c)
- sess.create_transaction()
+ sess.begin()
transaction = sess.transaction
u = User()
sess.save(u)
sess.flush()
- assert transaction.get_or_add(testing.db) is transaction.get_or_add(c) is c
+ assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c
- try:
- transaction.add(testing.db.connect())
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Session already has a Connection associated for the given Connection's Engine"
-
- try:
- transaction.get_or_add(testing.db.connect())
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Session already has a Connection associated for the given Connection's Engine"
-
- try:
- transaction.add(testing.db)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Session already has a Connection associated for the given Engine"
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect())
transaction.rollback()
assert len(sess.query(User).all()) == 0
mapper(User, users)
c = testing.db.connect()
- sess = create_session(bind=c, transactional=True)
+ sess = create_session(bind=c, autocommit=False)
u = User()
sess.save(u)
sess.flush()
assert not c.in_transaction()
assert c.scalar("select count(1) from users") == 0
- sess = create_session(bind=c, transactional=True)
+ sess = create_session(bind=c, autocommit=False)
u = User()
sess.save(u)
sess.flush()
c = testing.db.connect()
trans = c.begin()
- sess = create_session(bind=c, transactional=False)
+ sess = create_session(bind=c, autocommit=True)
u = User()
sess.save(u)
sess.flush()
user = User()
- try:
- s.update(user)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
-
- try:
- s.delete(user)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "is not persisted", s.update, user)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "is not persisted", s.delete, user)
s.save(user)
s.flush()
assert user in s
assert user not in s.dirty
- try:
- s.save(user)
- assert False
- except exceptions.InvalidRequestError, e:
- assert str(e) == "Instance 'User@%s' is already persistent" % hex(id(user))
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "is already persistent", s.save, user)
s2 = create_session()
- try:
- s2.delete(user)
- assert False
- except exceptions.InvalidRequestError, e:
- assert "is already attached to session" in str(e)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "is already attached to session", s2.delete, user)
u2 = s2.query(User).get(user.user_id)
- try:
- s.delete(u2)
- assert False
- except exceptions.InvalidRequestError, e:
- assert "already persisted with a different identity" in str(e)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "already persisted with a different identity", s.delete, u2)
s.delete(user)
s.flush()
del user
gc.collect()
assert len(s.identity_map) == 0
- assert len(s.identity_map.data) == 0
user = s.query(User).one()
user.user_name = 'fred'
del user
gc.collect()
assert len(s.identity_map) == 1
- assert len(s.identity_map.data) == 1
assert len(s.dirty) == 1
s.flush()
gc.collect()
assert not s.dirty
assert not s.identity_map
- assert not s.identity_map.data
user = s.query(User).one()
assert user.user_name == 'fred'
assert log == ['before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
log = []
- sess = create_session(transactional=True, extension=MyExt())
+ sess = create_session(autocommit=False, extension=MyExt())
u = User()
sess.save(u)
sess.flush()
assert log == ['before_commit', 'after_commit']
log = []
- sess = create_session(transactional=True, extension=MyExt(), bind=testing.db)
+ sess = create_session(autocommit=False, extension=MyExt(), bind=testing.db)
conn = sess.connection()
assert log == ['after_begin']
u1 = User()
sess1.save(u1)
- try:
- sess2.save(u1)
- assert False
- except exceptions.InvalidRequestError, e:
- assert "already attached to session" in str(e)
+ self.assertRaisesMessage(sa_exc.InvalidRequestError, "already attached to session", sess2.save, u1)
u2 = pickle.loads(pickle.dumps(u1))
sess.expunge(u1)
assert u1 not in sess
+ assert Session.object_session(u1) is None
u2 = sess.query(User).get(u1.user_id)
assert u2 is not None and u2 is not u1
sess.expunge(u2)
assert u2 not in sess
+ assert Session.object_session(u2) is None
u1.user_name = "John"
u2.user_name = "Doe"
sess.update(u1)
assert u1 in sess
+ assert Session.object_session(u1) is sess
sess.flush()
assert len(list(sess)) == 1
-class ScopedSessionTest(ORMTest):
-
- def define_tables(self, metadata):
- global table, table2
- table = Table('sometable', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(30)))
- table2 = Table('someothertable', metadata,
- Column('id', Integer, primary_key=True),
- Column('someid', None, ForeignKey('sometable.id'))
- )
-
- def test_basic(self):
- Session = scoped_session(sessionmaker())
-
- class SomeObject(fixtures.Base):
- query = Session.query_property()
- class SomeOtherObject(fixtures.Base):
- query = Session.query_property()
-
- mapper(SomeObject, table, properties={
- 'options':relation(SomeOtherObject)
- })
- mapper(SomeOtherObject, table2)
-
- s = SomeObject(id=1, data="hello")
- sso = SomeOtherObject()
- s.options.append(sso)
- Session.save(s)
- Session.commit()
- Session.remove()
-
- self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), Session.query(SomeObject).one())
- self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), SomeObject.query.one())
- self.assertEquals(SomeOtherObject(someid=1), SomeOtherObject.query.filter(SomeOtherObject.someid==sso.someid).one())
-
-class ScopedMapperTest(TestBase):
+class TLTransactionTest(TestBase):
def setUpAll(self):
- global metadata, table, table2
- metadata = MetaData(testing.db)
- table = Table('sometable', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(30), nullable=False))
- table2 = Table('someothertable', metadata,
- Column('id', Integer, primary_key=True),
- Column('someid', None, ForeignKey('sometable.id'))
- )
- metadata.create_all()
-
- def setUp(self):
- global SomeObject, SomeOtherObject
- class SomeObject(object):pass
- class SomeOtherObject(object):pass
-
- global Session
-
- Session = scoped_session(create_session)
- Session.mapper(SomeObject, table, properties={
- 'options':relation(SomeOtherObject)
- })
- Session.mapper(SomeOtherObject, table2)
-
- s = SomeObject()
- s.id = 1
- s.data = 'hello'
- sso = SomeOtherObject()
- s.options.append(sso)
- Session.flush()
- Session.clear()
-
- def tearDownAll(self):
- metadata.drop_all()
-
+ global users, metadata, tlengine
+ tlengine = create_engine(testing.db.url, strategy='threadlocal')
+ metadata = MetaData()
+ users = Table('query_users', metadata,
+ Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True),
+ Column('user_name', VARCHAR(20)),
+ test_needs_acid=True,
+ )
+ users.create(tlengine)
def tearDown(self):
- for table in metadata.table_iterator(reverse=True):
- table.delete().execute()
- clear_mappers()
-
- def test_query(self):
- sso = SomeOtherObject.query().first()
- assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+ tlengine.execute(users.delete())
- def test_query_compiles(self):
- class Foo(object):
- pass
- Session.mapper(Foo, table2)
- assert hasattr(Foo, 'query')
-
- ext = MapperExtension()
-
- class Bar(object):
- pass
- Session.mapper(Bar, table2, extension=[ext])
- assert hasattr(Bar, 'query')
+ def tearDownAll(self):
+ users.drop(tlengine)
+ tlengine.dispose()
- class Baz(object):
+ @testing.exclude('mysql', '<', (5, 0, 3))
+ def testsessionnesting(self):
+ class User(object):
pass
- Session.mapper(Baz, table2, extension=ext)
- assert hasattr(Baz, 'query')
-
- def test_validating_constructor(self):
- s2 = SomeObject(someid=12)
- s3 = SomeOtherObject(someid=123, bogus=345)
-
- class ValidatedOtherObject(object):pass
- Session.mapper(ValidatedOtherObject, table2, validate=True)
-
- v1 = ValidatedOtherObject(someid=12)
try:
- v2 = ValidatedOtherObject(someid=12, bogus=345)
- assert False
- except exceptions.ArgumentError:
- pass
-
- def test_dont_clobber_methods(self):
- class MyClass(object):
- def expunge(self):
- return "an expunge !"
-
- Session.mapper(MyClass, table2)
-
- assert MyClass().expunge() == "an expunge !"
-
- def _test_autoflush_saveoninit(self, on_init, autoflush=None):
- Session = scoped_session(
- sessionmaker(transactional=True, autoflush=True))
-
- class Foo(object):
- def __init__(self, data=None):
- if autoflush is not None:
- friends = Session.query(Foo).autoflush(autoflush).all()
- else:
- friends = Session.query(Foo).all()
- self.data = data
-
- Session.mapper(Foo, table, save_on_init=on_init)
-
- a1 = Foo('an address')
- Session.flush()
-
- def test_autoflush_saveoninit(self):
- """Test save_on_init + query.autoflush()"""
- self._test_autoflush_saveoninit(False)
- self._test_autoflush_saveoninit(False, True)
- self._test_autoflush_saveoninit(False, False)
-
- self.assertRaises(exceptions.DBAPIError,
- self._test_autoflush_saveoninit, True)
- self.assertRaises(exceptions.DBAPIError,
- self._test_autoflush_saveoninit, True, True)
- self._test_autoflush_saveoninit(True, False)
-
-
-class ScopedMapperTest2(ORMTest):
- def define_tables(self, metadata):
- global table, table2
- table = Table('sometable', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(30)),
- Column('type', String(30))
-
- )
- table2 = Table('someothertable', metadata,
- Column('id', Integer, primary_key=True),
- Column('someid', None, ForeignKey('sometable.id')),
- Column('somedata', String(30)),
- )
-
- def test_inheritance(self):
- def expunge_list(l):
- for x in l:
- Session.expunge(x)
- return l
-
- class BaseClass(fixtures.Base):
- pass
- class SubClass(BaseClass):
- pass
-
- Session = scoped_session(sessionmaker())
- Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type)
- Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass)
-
- b = BaseClass(data='b1')
- s = SubClass(data='s1', somedata='somedata')
- Session.commit()
- Session.clear()
-
- assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all()
- assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all()
+ mapper(User, users)
+ sess = create_session(bind=tlengine)
+ tlengine.begin()
+ u = User()
+ sess.save(u)
+ sess.flush()
+ tlengine.commit()
+ finally:
+ clear_mappers()
if __name__ == "__main__":
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
-from sqlalchemy.orm.session import object_session, Session
-from testlib import *
-
-
-metadata = MetaData()
-users = Table('users', metadata,
- Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
- Column('user_name', String(40)),
-)
-
-class SessionContextTest(TestBase, AssertsExecutionResults):
- def setUp(self):
- clear_mappers()
-
- def do_test(self, class_, context):
- """test session assignment on object creation"""
- obj = class_()
- assert context.current == object_session(obj)
-
- # keep a reference so the old session doesn't get gc'd
- old_session = context.current
-
- context.current = Session()
- assert context.current != object_session(obj)
- assert old_session == object_session(obj)
-
- new_session = context.current
- del context.current
- assert context.current != new_session
- assert old_session == object_session(obj)
-
- obj2 = class_()
- assert context.current == object_session(obj2)
-
- @testing.uses_deprecated('SessionContext')
- def test_mapper_extension(self):
- context = SessionContext(Session)
- class User(object): pass
- User.mapper = mapper(User, users, extension=context.mapper_extension)
- self.do_test(User, context)
-
-
-if __name__ == "__main__":
- testenv.main()
import testenv; testenv.configure_for_tests()
import datetime, os
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import sql
from sqlalchemy.orm import *
from sqlalchemy.orm.shard import ShardedSession
from sqlalchemy.sql import operators
else:
return ids
- create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True)
+ create_session = sessionmaker(class_=ShardedSession, autoflush=True, autocommit=False)
create_session.configure(shards={
'north_america':db1,
for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
sess.save(c)
sess.commit()
-
+ tokyo.city # reload 'city' attribute on tokyo
sess.clear()
assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')]
--- /dev/null
+import testenv; testenv.configure_for_tests()
+import operator
+from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import *
+
+
+class TransactionTest(FixtureTest):
+ keep_mappers = True
+ session = sessionmaker()
+
+ def setup_mappers(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user',
+ cascade="all, delete-orphan"),
+ })
+ mapper(Address, addresses)
+
+
+class FixtureDataTest(TransactionTest):
+ refresh_data = True
+
+ def test_attrs_on_rollback(self):
+ sess = self.session()
+ u1 = sess.get(User, 7)
+ u1.name = 'ed'
+ sess.rollback()
+ self.assertEquals(u1.name, 'jack')
+
+ def test_commit_persistent(self):
+ sess = self.session()
+ u1 = sess.get(User, 7)
+ u1.name = 'ed'
+ sess.flush()
+ sess.commit()
+ self.assertEquals(u1.name, 'ed')
+
+ def test_concurrent_commit_persistent(self):
+ s1 = self.session()
+ u1 = s1.get(User, 7)
+ u1.name = 'ed'
+ s1.commit()
+
+ s2 = self.session()
+ u2 = s2.get(User, 7)
+ assert u2.name == 'ed'
+ u2.name = 'will'
+ s2.commit()
+
+ assert u1.name == 'will'
+
+class AutoExpireTest(TransactionTest):
+ tables_only = True
+
+ def test_expunge_pending_on_rollback(self):
+ sess = self.session()
+ u2 = User(name='newuser')
+ sess.add(u2)
+ assert u2 in sess
+ sess.rollback()
+ assert u2 not in sess
+
+ def test_trans_pending_cleared_on_commit(self):
+ sess = self.session()
+ u2 = User(name='newuser')
+ sess.add(u2)
+ assert u2 in sess
+ sess.commit()
+ assert u2 in sess
+ u3 = User(name='anotheruser')
+ sess.add(u3)
+ sess.rollback()
+ assert u3 not in sess
+ assert u2 in sess
+
+ def test_update_deleted_on_rollback(self):
+ s = self.session()
+ u1 = User(name='ed')
+ s.add(u1)
+ s.commit()
+
+ s.delete(u1)
+ assert u1 in s.deleted
+ s.rollback()
+ assert u1 in s
+ assert u1 not in s.deleted
+
+ def test_trans_deleted_cleared_on_rollback(self):
+ s = self.session()
+ u1 = User(name='ed')
+ s.add(u1)
+ s.commit()
+
+ s.delete(u1)
+ s.commit()
+ assert u1 not in s
+ s.rollback()
+ assert u1 not in s
+
+ def test_update_deleted_on_rollback_cascade(self):
+ s = self.session()
+ u1 = User(name='ed', addresses=[Address(email_address='foo')])
+ s.add(u1)
+ s.commit()
+
+ s.delete(u1)
+ assert u1 in s.deleted
+ assert u1.addresses[0] in s.deleted
+ s.rollback()
+ assert u1 in s
+ assert u1 not in s.deleted
+ assert u1.addresses[0] not in s.deleted
+
+ def test_update_deleted_on_rollback_orphan(self):
+ s = self.session()
+ u1 = User(name='ed', addresses=[Address(email_address='foo')])
+ s.add(u1)
+ s.commit()
+
+ a1 = u1.addresses[0]
+ u1.addresses.remove(a1)
+
+ s.flush()
+ self.assertEquals(s.query(Address).filter(Address.email_address=='foo').all(), [])
+ s.rollback()
+ assert a1 not in s.deleted
+ assert u1.addresses == [a1]
+
+ def test_commit_pending(self):
+ sess = self.session()
+ u1 = User(name='newuser')
+ sess.add(u1)
+ sess.flush()
+ sess.commit()
+ self.assertEquals(u1.name, 'newuser')
+
+
+ def test_concurrent_commit_pending(self):
+ s1 = self.session()
+ u1 = User(name='edward')
+ s1.add(u1)
+ s1.commit()
+
+ s2 = self.session()
+ u2 = s2.query(User).filter(User.name=='edward').one()
+ u2.name = 'will'
+ s2.commit()
+
+ assert u1.name == 'will'
+
+class RollbackRecoverTest(TransactionTest):
+ only_tables = True
+
+ def test_pk_violation(self):
+ s = self.session()
+ a1 = Address(email_address='foo')
+ u1 = User(id=1, name='ed', addresses=[a1])
+ s.add(u1)
+ s.commit()
+
+ a2 = Address(email_address='bar')
+ u2 = User(id=1, name='jack', addresses=[a2])
+
+ u1.name = 'edward'
+ a1.email_address = 'foober'
+ s.add(u2)
+ self.assertRaises(sa_exc.FlushError, s.commit)
+ self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+ s.rollback()
+ assert u2 not in s
+ assert a2 not in s
+ assert u1 in s
+ assert a1 in s
+ assert u1.name == 'ed'
+ assert a1.email_address == 'foo'
+ u1.name = 'edward'
+ a1.email_address = 'foober'
+ s.commit()
+ assert s.query(User).all() == [User(id=1, name='edward', addresses=[Address(email_address='foober')])]
+
+ @testing.requires.savepoints
+ def test_pk_violation_with_savepoint(self):
+ s = self.session()
+ a1 = Address(email_address='foo')
+ u1 = User(id=1, name='ed', addresses=[a1])
+ s.add(u1)
+ s.commit()
+
+ a2 = Address(email_address='bar')
+ u2 = User(id=1, name='jack', addresses=[a2])
+
+ u1.name = 'edward'
+ a1.email_address = 'foober'
+ s.begin_nested()
+ s.add(u2)
+ self.assertRaises(sa_exc.FlushError, s.commit)
+ self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+ s.rollback()
+ assert u2 not in s
+ assert a2 not in s
+ assert u1 in s
+ assert a1 in s
+
+ s.commit()
+ assert s.query(User).all() == [User(id=1, name='edward', addresses=[Address(email_address='foober')])]
+
+
+class SavepointTest(TransactionTest):
+
+ only_tables = True
+
+ @testing.requires.savepoints
+ def test_savepoint_rollback(self):
+ s = self.session()
+ u1 = User(name='ed')
+ u2 = User(name='jack')
+ s.add_all([u1, u2])
+
+ s.begin_nested()
+ u3 = User(name='wendy')
+ u4 = User(name='foo')
+ u1.name = 'edward'
+ u2.name = 'jackward'
+ s.add_all([u3, u4])
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ s.rollback()
+ assert u1.name == 'ed'
+ assert u2.name == 'jack'
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+ s.commit()
+ assert u1.name == 'ed'
+ assert u2.name == 'jack'
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+
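+ # sketch of the SAVEPOINT semantics verified above: rolling back a
+ # begin_nested() block reverts both the database rows and the
+ # in-memory attribute state to the savepoint, while the enclosing
+ # transaction remains usable:
+ #
+ #     s.begin_nested()      # SAVEPOINT
+ #     u1.name = 'edward'
+ #     s.rollback()          # ROLLBACK TO SAVEPOINT; u1.name == 'ed'
+ #     s.commit()            # outer transaction commits normally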
+ @testing.requires.savepoints
+ def test_savepoint_commit(self):
+ s = self.session()
+ u1 = User(name='ed')
+ u2 = User(name='jack')
+ s.add_all([u1, u2])
+
+ s.begin_nested()
+ u3 = User(name='wendy')
+ u4 = User(name='foo')
+ u1.name = 'edward'
+ u2.name = 'jackward'
+ s.add_all([u3, u4])
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ s.commit()
+ def go():
+ assert u1.name == 'edward'
+ assert u2.name == 'jackward'
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ self.assert_sql_count(testing.db, go, 1)
+
+ s.commit()
+ self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+
+ @testing.requires.savepoints
+ def test_savepoint_rollback_collections(self):
+ s = self.session()
+ u1 = User(name='ed', addresses=[Address(email_address='foo')])
+ s.add(u1)
+ s.commit()
+
+ u1.name='edward'
+ u1.addresses.append(Address(email_address='bar'))
+ s.begin_nested()
+ u2 = User(name='jack', addresses=[Address(email_address='bat')])
+ s.add(u2)
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ User(name='jack', addresses=[Address(email_address='bat')])
+ ]
+ )
+ s.rollback()
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ ]
+ )
+ s.commit()
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ ]
+ )
+
+ @testing.requires.savepoints
+ def test_savepoint_commit_collections(self):
+ s = self.session()
+ u1 = User(name='ed', addresses=[Address(email_address='foo')])
+ s.add(u1)
+ s.commit()
+
+ u1.name='edward'
+ u1.addresses.append(Address(email_address='bar'))
+ s.begin_nested()
+ u2 = User(name='jack', addresses=[Address(email_address='bat')])
+ s.add(u2)
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ User(name='jack', addresses=[Address(email_address='bat')])
+ ]
+ )
+ s.commit()
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ User(name='jack', addresses=[Address(email_address='bat')])
+ ]
+ )
+ s.commit()
+ self.assertEquals(s.query(User).order_by(User.id).all(),
+ [
+ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+ User(name='jack', addresses=[Address(email_address='bat')])
+ ]
+ )
+
+ @testing.requires.savepoints
+ def test_expunge_pending_on_rollback(self):
+ sess = self.session()
+
+ sess.begin_nested()
+ u2 = User(name='newuser')
+ sess.add(u2)
+ assert u2 in sess
+ sess.rollback()
+ assert u2 not in sess
+
+ @testing.requires.savepoints
+ def test_update_deleted_on_rollback(self):
+ s = self.session()
+ u1 = User(name='ed')
+ s.add(u1)
+ s.commit()
+
+ s.begin_nested()
+ s.delete(u1)
+ assert u1 in s.deleted
+ s.rollback()
+ assert u1 in s
+ assert u1 not in s.deleted
+
+
+
+class AutocommitTest(TransactionTest):
+ def test_begin_nested_requires_trans(self):
+ sess = create_session(autocommit=True)
+ self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested)
+
+
+
+if __name__ == '__main__':
+ testenv.main()
import testenv; testenv.configure_for_tests()
import pickleable
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc as sa_exc, sql
from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
from testlib import *
from testlib.tables import *
from testlib import engines, tables, fixtures
# TODO: convert suite to not use Session.mapper, use fixtures.Base
# with explicit session.save()
-Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+Session = scoped_session(sessionmaker(autoflush=True, autocommit=False, autoexpire=False))
orm_mapper = mapper
mapper = Session.mapper
def test_backref(self):
s = Session()
- class User(object):pass
- class Address(object):pass
+ class User(object):
+ def __init__(self, **kw): pass
+ class Address(object):
+ def __init__(self, _sa_session=None): pass
am = mapper(Address, addresses)
m = mapper(User, users, properties = dict(
addresses = relation(am, backref='user', lazy=False))
@engines.close_open_connections
def test_basic(self):
s = Session(scope=None)
- class Foo(object):pass
+ class Foo(object):
+ def __init__(self, value, _sa_session=None):
+ self.value = value
mapper(Foo, version_table, version_id_col=version_table.c.version_id)
f1 = Foo(value='f1', _sa_session=s)
f2 = Foo(value='f2', _sa_session=s)
f1.value='f1rev2'
s.commit()
+
s2 = Session()
f1_s = s2.query(Foo).get(f1.id)
f1_s.value='f1rev3'
s2.commit()
f1.value='f1rev3mine'
- success = False
- try:
- # a concurrent session has modified this, should throw
- # an exception
- s.commit()
- except exceptions.ConcurrentModificationError, e:
- #print e
- success = True
# Only dialects with a sane rowcount can detect the ConcurrentModificationError
if testing.db.dialect.supports_sane_rowcount:
- assert success
-
- s.close()
+ self.assertRaises(orm_exc.ConcurrentModificationError, s.commit)
+ s.rollback()
+ else:
+ s.commit()
+
+ # new in 0.5! don't need to close the session
f1 = s.query(Foo).get(f1.id)
f2 = s.query(Foo).get(f2.id)
s.delete(f1)
s.delete(f2)
- success = False
- try:
- s.commit()
- except exceptions.ConcurrentModificationError, e:
- #print e
- success = True
+
if testing.db.dialect.supports_sane_multi_rowcount:
- assert success
+ self.assertRaises(orm_exc.ConcurrentModificationError, s.commit)
+ else:
+ s.commit()
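+ # sketch of the versioning setup this class exercises (version_table
+ # as defined in this module): the mapper bumps the version column on
+ # each UPDATE/DELETE and includes it in the WHERE clause, so a stale
+ # in-memory version surfaces as ConcurrentModificationError on
+ # dialects that report accurate rowcounts:
+ #
+ #     mapper(Foo, version_table,
+ #            version_id_col=version_table.c.version_id)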
@engines.close_open_connections
def test_versioncheck(self):
"""test that query.with_lockmode performs a 'version check' on an already loaded instance"""
s1 = Session(scope=None)
- class Foo(object):pass
+ class Foo(object):
+ def __init__(self, _sa_session=None): pass
mapper(Foo, version_table, version_id_col=version_table.c.version_id)
- f1s1 =Foo(value='f1', _sa_session=s1)
+ f1s1 = Foo(_sa_session=s1)
+ f1s1.value = 'f1 value'
s1.commit()
s2 = Session()
f1s2 = s2.query(Foo).get(f1s1.id)
f1s2.value='f1 new value'
s2.commit()
- try:
- # load, version is wrong
- s1.query(Foo).with_lockmode('read').get(f1s1.id)
- assert False
- except exceptions.ConcurrentModificationError, e:
- assert True
+ # load, version is wrong
+ self.assertRaises(orm_exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id)
+
# reload it
s1.query(Foo).load(f1s1.id)
# now assert version OK
def test_noversioncheck(self):
"""test that query.with_lockmode works OK when the mapper has no version id col"""
s1 = Session()
- class Foo(object):pass
+ class Foo(object):
+ def __init__(self, _sa_session=None): pass
mapper(Foo, version_table)
- f1s1 =Foo(value='f1', _sa_session=s1)
+ f1s1 = Foo(_sa_session=s1)
+ f1s1.value = 'foo'
f1s1.version_id=0
s1.commit()
s2 = Session()
Session.commit()
Session.close()
f2 = Session.query(Foo).filter_by(id=f1.id).one()
+ assert 'data' in attributes.instance_state(f2).unmodified
assert f2.data == f1.data
f2.data.y = 19
assert f2 in Session.dirty
+ assert 'data' not in attributes.instance_state(f2).unmodified
Session.commit()
Session.close()
f3 = Session.query(Foo).filter_by(id=f1.id).one()
e.multi_rev = 2
Session.commit()
Session.close()
- e2 = Query(Entry).get((e.multi_id, 2))
- self.assert_(e is not e2 and e._instance_key == e2._instance_key)
+ e2 = Session.query(Entry).get((e.multi_id, 2))
+ self.assert_(e is not e2)
+ state = attributes.instance_state(e)
+ state2 = attributes.instance_state(e2)
+ self.assert_(state.key == state2.key)
# this one works with sqlite since we are manually setting up pk values
def test_manualpk(self):
Column('counter', Integer, default=1))
def test_update(self):
- class User(object):
- pass
+ class User(fixtures.Base): pass
mapper(User, users_table)
u = User(name='test')
sess = Session()
self.assert_sql_count(testing.db, go, 1)
def test_multi_update(self):
- class User(object):
- pass
+ class User(fixtures.Base): pass
mapper(User, users_table)
u = User(name='test')
sess = Session()
@testing.unsupported('mssql')
def test_insert(self):
- class User(object):
- pass
+ class User(fixtures.Base): pass
mapper(User, users_table)
u = User(name='test', counter=select([5]))
sess = Session()
'children':relation(MyOtherClass, passive_deletes='all', cascade="all")
})
assert False
- except exceptions.ArgumentError, e:
+ except sa_exc.ArgumentError, e:
assert str(e) == "Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade"
@testing.unsupported('sqlite')
assert myothertable.count().scalar() == 4
mc = sess.query(MyClass).get(mc.id)
sess.delete(mc)
- self.assertRaises(exceptions.DBAPIError, sess.commit)
+ self.assertRaises(sa_exc.DBAPIError, sess.commit)
@testing.unsupported('sqlite')
def test_extra_passive_2(self):
mc = sess.query(MyClass).get(mc.id)
sess.delete(mc)
mc.children[0].data = 'some new data'
- self.assertRaises(exceptions.DBAPIError, sess.commit)
+ self.assertRaises(sa_exc.DBAPIError, sess.commit)
class DefaultTest(ORMTest):
secondary_table.append_column(Column('hoho', hohotype, ForeignKey('default_test.hoho')))
def test_insert(self):
- class Hoho(object):pass
+ class Hoho(fixtures.Base): pass
mapper(Hoho, default_table)
h1 = Hoho(hoho=althohoval)
def test_insert_nopostfetch(self):
# populates the PassiveDefaults explicitly so there is no "post-update"
- class Hoho(object):pass
+ class Hoho(fixtures.Base): pass
mapper(Hoho, default_table)
h1 = Hoho(hoho="15", counter="15")
self.assert_sql_count(testing.db, go, 0)
def test_update(self):
- class Hoho(object):pass
+ class Hoho(fixtures.Base): pass
mapper(Hoho, default_table)
h1 = Hoho()
Session.commit()
def test_o2m_delete_parent(self):
m = mapper(User, users, properties = dict(
- address = relation(mapper(Address, addresses), lazy = True, uselist = False, private = False)
+ address = relation(mapper(Address, addresses), lazy=True, uselist=False)
))
u = User()
a = Address()
Session.commit()
Session.delete(u)
Session.commit()
- self.assert_(a.address_id is not None and a.user_id is None and u._instance_key not in Session.identity_map and a._instance_key in Session.identity_map)
+ self.assert_(a.address_id is not None)
+ self.assert_(a.user_id is None)
+ self.assert_(attributes.instance_state(a).key in Session.identity_map)
+ self.assert_(attributes.instance_state(u).key not in Session.identity_map)
def test_onetoone(self):
m = mapper(User, users, properties = dict(
orm_mapper(T2, t2)
def test_close_transaction_on_commit_fail(self):
- Session = sessionmaker(autoflush=False, transactional=False)
+ Session = sessionmaker(autoflush=False, autocommit=True)
sess = Session()
# with a deferred constraint, this fails at COMMIT time instead
--- /dev/null
+import testenv; testenv.configure_for_tests()
+from sqlalchemy.orm import interfaces, util
+from testlib import *
+from testlib import fixtures
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import mapper
+
+
+class ExtensionCarrierTest(TestBase):
+ def test_basic(self):
+ carrier = util.ExtensionCarrier()
+
+ assert 'translate_row' not in carrier.methods
+ assert carrier.translate_row() is interfaces.EXT_CONTINUE
+ assert 'translate_row' not in carrier.methods
+
+ self.assertRaises(AttributeError, lambda: carrier.snickysnack)
+
+ class Partial(object):
+ def __init__(self, marker):
+ self.marker = marker
+ def translate_row(self, row):
+ return self.marker
+
+ carrier.append(Partial('end'))
+ assert 'translate_row' in carrier.methods
+ assert carrier.translate_row(None) == 'end'
+
+ carrier.push(Partial('front'))
+ assert carrier.translate_row(None) == 'front'
+
+ assert 'populate_instance' not in carrier.methods
+ carrier.append(interfaces.MapperExtension)
+ assert 'populate_instance' in carrier.methods
+
+ assert carrier.interface
+ for m in carrier.interface:
+ assert getattr(interfaces.MapperExtension, m)
+
+class AliasedClassTest(TestBase):
+ def point_map(self, cls):
+ table = Table('point', MetaData(),
+ Column('id', Integer(), primary_key=True),
+ Column('x', Integer),
+ Column('y', Integer))
+ mapper(cls, table)
+ return table
+
+ def test_simple(self):
+ class Point(object):
+ pass
+ table = self.point_map(Point)
+
+ alias = aliased(Point)
+
+ assert alias.id
+ assert alias.x
+ assert alias.y
+
+ assert Point.id.__clause_element__().table is table
+ assert alias.id.__clause_element__().table is not table
+
+ def test_notcallable(self):
+ class Point(object):
+ pass
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ self.assertRaises(TypeError, alias)
+
+ def test_instancemethods(self):
+ class Point(object):
+ def zero(self):
+ self.x, self.y = 0, 0
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.zero
+ assert not getattr(alias, 'zero')
+
+ def test_classmethods(self):
+ class Point(object):
+ @classmethod
+ def max_x(cls):
+ return 100
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.max_x
+ assert alias.max_x
+ assert Point.max_x() == alias.max_x()
+
+ def test_simpleproperties(self):
+ class Point(object):
+ @property
+ def max_x(self):
+ return 100
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.max_x
+ assert Point.max_x != 100
+ assert alias.max_x
+ assert Point.max_x is alias.max_x
+
+ def test_descriptors(self):
+ class descriptor(object):
+ """Tortured..."""
+ def __init__(self, fn):
+ self.fn = fn
+ def __get__(self, obj, owner):
+ if obj is not None:
+ return self.fn(obj, obj)
+ else:
+ return self
+ def method(self):
+ return 'method'
+
+ class Point(object):
+ center = (0, 0)
+ @descriptor
+ def thing(self, arg):
+ return arg.center
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+
+ assert Point.thing != (0, 0)
+ assert Point().thing == (0, 0)
+ assert Point.thing.method() == 'method'
+
+ assert alias.thing != (0, 0)
+ assert alias.thing.method() == 'method'
+
+ def test_hybrid_descriptors(self):
+ from sqlalchemy import Column # override testlib's override
+ import new
+
+ class MethodDescriptor(object):
+ def __init__(self, func):
+ self.func = func
+ def __get__(self, instance, owner):
+ if instance is None:
+ args = (self.func, owner, owner.__class__)
+ else:
+ args = (self.func, instance, owner)
+ return new.instancemethod(*args)
+
+ class PropertyDescriptor(object):
+ def __init__(self, fget, fset, fdel):
+ self.fget = fget
+ self.fset = fset
+ self.fdel = fdel
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.fget(owner)
+ else:
+ return self.fget(instance)
+ def __set__(self, instance, value):
+ self.fset(instance, value)
+ def __delete__(self, instance):
+ self.fdel(instance)
+ hybrid = MethodDescriptor
+ def hybrid_property(fget, fset=None, fdel=None):
+ return PropertyDescriptor(fget, fset, fdel)
+
+ def assert_table(expr, table):
+ for child in expr.get_children():
+ if isinstance(child, Column):
+ assert child.table is table
+
+ class Point(object):
+ def __init__(self, x, y):
+ self.x, self.y = x, y
+ @hybrid
+ def left_of(self, other):
+ return self.x < other.x
+
+ double_x = hybrid_property(lambda self: self.x * 2)
+
+ table = self.point_map(Point)
+ alias = aliased(Point)
+ alias_table = alias.x.__clause_element__().table
+ assert table is not alias_table
+
+ p1 = Point(-10, -10)
+ p2 = Point(20, 20)
+
+ assert p1.left_of(p2)
+ assert p1.double_x == -20
+
+ assert_table(Point.double_x, table)
+ assert_table(alias.double_x, alias_table)
+
+ assert_table(Point.left_of(p2), table)
+ assert_table(alias.left_of(p2), alias_table)
+
+
+if __name__ == '__main__':
+ testenv.main()
+
@profiling.profiled('masseagerload', always=True, sort=['cumulative'])
def masseagerload(session):
+ session.begin()
query = session.query(Item)
l = query.all()
print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
Column('c1', Integer, primary_key=True),
Column('c2', String(30)))
- @profiling.function_call_count(74, {'2.3': 44, '2.4': 42})
+ @profiling.function_call_count(67, {'2.3': 44, '2.4': 42})
def test_insert(self):
t1.insert().compile()
- @profiling.function_call_count(75, {'2.3': 47, '2.4': 42})
+ @profiling.function_call_count(68, {'2.3': 47, '2.4': 42})
def test_update(self):
t1.update().compile()
def test_profile_2_insert(self):
self.test_baseline_2_insert()
- @profiling.function_call_count(4923, {'2.4': 2557})
+ @profiling.function_call_count(4662, {'2.4': 2557})
def test_profile_3_properties(self):
self.test_baseline_3_properties()
def test_profile_5_aggregates(self):
self.test_baseline_5_aggregates()
- @profiling.function_call_count(1988, {'2.4': 1048})
+ @profiling.function_call_count(1882, {'2.4': 1048})
def test_profile_6_editing(self):
self.test_baseline_6_editing()
import sys
from sqlalchemy import *
from testlib import *
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
from sqlalchemy.sql import table, column
def test_literal_interpretation(self):
t = table('test', column('col1'))
- self.assertRaises(exceptions.ArgumentError, case, [("x", "y")])
+ self.assertRaises(exc.ArgumentError, case, [("x", "y")])
self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END")
self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END")
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
from testlib import *
from sqlalchemy import Table, Column # don't use testlib's wrappers
def test_incomplete(self):
c = self.columns()
- self.assertRaises(exceptions.ArgumentError, Table, 't', MetaData(), *c)
+ self.assertRaises(exc.ArgumentError, Table, 't', MetaData(), *c)
def test_incomplete_key(self):
c = Column(Integer)
def test_bogus(self):
- self.assertRaises(exceptions.ArgumentError, Column, 'foo', name='bar')
- self.assertRaises(exceptions.ArgumentError, Column, 'foo', Integer,
+ self.assertRaises(exc.ArgumentError, Column, 'foo', name='bar')
+ self.assertRaises(exc.ArgumentError, Column, 'foo', Integer,
type_=Integer())
if __name__ == "__main__":
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
from testlib import *
from testlib import config, engines
try:
foo.insert().execute(id=2,x=5,y=9)
assert False
- except exceptions.SQLError:
+ except exc.SQLError:
assert True
bar.insert().execute(id=1,x=10)
try:
bar.insert().execute(id=2,x=5)
assert False
- except exceptions.SQLError:
+ except exc.SQLError:
assert True
def test_unique_constraint(self):
try:
foo.insert().execute(id=3, value='value1')
assert False
- except exceptions.SQLError:
+ except exc.SQLError:
assert True
try:
bar.insert().execute(id=3, value='a', value2='b')
assert False
- except exceptions.SQLError:
+ except exc.SQLError:
assert True
def test_index_create(self):
import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions, schema, util
+from sqlalchemy import exc, schema, util
from sqlalchemy.orm import mapper, create_session
from testlib import *
try:
c = ColumnDefault(fn)
assert False, str(fn)
- except exceptions.ArgumentError, e:
+ except exc.ArgumentError, e:
assert str(e) == ex_msg
def test_argsignature(self):
nonai_table.insert().execute(data='row 1')
nonai_table.insert().execute(data='row 2')
assert False
- except exceptions.SQLError, e:
+ except exc.SQLError, e:
print "Got exception", str(e)
assert True
import datetime
from sqlalchemy import *
from sqlalchemy.sql import table, column
-from sqlalchemy import databases, exceptions, sql, util
+from sqlalchemy import databases, sql, util
from sqlalchemy.sql.compiler import BIND_TEMPLATES
from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.sql import table, column, ClauseElement
-from sqlalchemy.sql.expression import _clone
+from sqlalchemy.sql.expression import _clone, _from_objects
from testlib import *
from sqlalchemy.sql.visitors import *
from sqlalchemy import util
def test_clone(self):
struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_a(self, a):
pass
def visit_b(self, b):
pass
vis = Vis()
- s2 = vis.traverse(struct, clone=True)
+ s2 = vis.traverse(struct)
assert struct == s2
assert not struct.is_other(s2)
pass
vis = Vis()
- s2 = vis.traverse(struct, clone=False)
+ s2 = vis.traverse(struct)
assert struct == s2
assert struct.is_other(s2)
struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3"))
struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_a(self, a):
if a.expr == "expr2":
a.expr = "expr2modified"
pass
vis = Vis()
- s2 = vis.traverse(struct, clone=True)
+ s2 = vis.traverse(struct)
assert struct != s2
assert not struct.is_other(s2)
assert struct2 == s2
- class Vis2(ClauseVisitor):
+ class Vis2(CloningVisitor):
def visit_a(self, a):
if a.expr == "expr2b":
a.expr = "expr2bmodified"
pass
vis2 = Vis2()
- s3 = vis2.traverse(struct, clone=True)
+ s3 = vis2.traverse(struct)
assert struct != s3
assert struct3 == s3
def test_binary(self):
clause = t1.c.col2 == t2.c.col2
- assert str(clause) == ClauseVisitor().traverse(clause, clone=True)
+ assert str(clause) == CloningVisitor().traverse(clause)
def test_binary_anon_label_quirk(self):
t = table('t1', column('col1'))
def test_join(self):
clause = t1.join(t2, t1.c.col2==t2.c.col2)
c1 = str(clause)
- assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True))
+ assert str(clause) == str(CloningVisitor().traverse(clause))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_binary(self, binary):
binary.right = t2.c.col3
- clause2 = Vis().traverse(clause, clone=True)
+ clause2 = Vis().traverse(clause)
assert c1 == str(clause)
assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
def test_text(self):
clause = text("select * from table where foo=:bar", bindparams=[bindparam('bar')])
c1 = str(clause)
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_textclause(self, text):
text.text = text.text + " SOME MODIFIER=:lala"
text.bindparams['lala'] = bindparam('lala')
- clause2 = Vis().traverse(clause, clone=True)
+ clause2 = Vis().traverse(clause)
assert c1 == str(clause)
assert str(clause2) == c1 + " SOME MODIFIER=:lala"
assert clause.bindparams.keys() == ['bar']
s2 = select([t1])
s2_assert = str(s2)
s3_assert = str(select([t1], t1.c.col2==7))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_select(self, select):
select.append_whereclause(t1.c.col2==7)
- s3 = Vis().traverse(s2, clone=True)
+ s3 = Vis().traverse(s2)
assert str(s3) == s3_assert
assert str(s2) == s2_assert
print str(s2)
print str(s3)
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col2==7)
Vis().traverse(s2)
assert str(s2) == s3_assert
print "------------------"
s4_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col3==9)))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_select(self, select):
select.append_whereclause(t1.c.col3==9)
- s4 = Vis().traverse(s3, clone=True)
+ s4 = Vis().traverse(s3)
print str(s3)
print str(s4)
assert str(s4) == s4_assert
print "------------------"
s5_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col1==9)))
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_binary(self, binary):
if binary.left is t1.c.col3:
binary.left = t1.c.col1
binary.right = bindparam("col1", unique=True)
- s5 = Vis().traverse(s4, clone=True)
+ s5 = Vis().traverse(s4)
print str(s4)
print str(s5)
assert str(s5) == s5_assert
def test_union(self):
u = union(t1.select(), t2.select())
- u2 = ClauseVisitor().traverse(u, clone=True)
+ u2 = CloningVisitor().traverse(u)
assert str(u) == str(u2)
assert [str(c) for c in u2.c] == [str(c) for c in u.c]
u = union(t1.select(), t2.select())
cols = [str(c) for c in u.c]
- u2 = ClauseVisitor().traverse(u, clone=True)
+ u2 = CloningVisitor().traverse(u)
assert str(u) == str(u2)
assert [str(c) for c in u2.c] == cols
"""test that unique bindparams change their name upon clone() to prevent conflicts"""
s = select([t1], t1.c.col1==bindparam(None, unique=True)).alias()
- s2 = ClauseVisitor().traverse(s, clone=True).alias()
+ s2 = CloningVisitor().traverse(s).alias()
s3 = select([s], s.c.col2==s2.c.col2)
self.assert_compile(s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "\
"WHERE anon_1.col2 = anon_2.col2")
s = select([t1], t1.c.col1==4).alias()
- s2 = ClauseVisitor().traverse(s, clone=True).alias()
+ s2 = CloningVisitor().traverse(s).alias()
s3 = select([s], s.c.col2==s2.c.col2)
self.assert_compile(s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "\
"table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1, "\
subq = t2.select().alias('subq')
s = select([t1.c.col1, subq.c.col1], from_obj=[t1, subq, t1.join(subq, t1.c.col1==subq.c.col2)])
orig = str(s)
- s2 = ClauseVisitor().traverse(s, clone=True)
+ s2 = CloningVisitor().traverse(s)
assert orig == str(s) == str(s2)
- s4 = ClauseVisitor().traverse(s2, clone=True)
+ s4 = CloningVisitor().traverse(s2)
assert orig == str(s) == str(s2) == str(s4)
- s3 = sql_util.ClauseAdapter(table('foo')).traverse(s, clone=True)
+ s3 = sql_util.ClauseAdapter(table('foo')).traverse(s)
assert orig == str(s) == str(s3)
- s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3, clone=True)
+ s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3)
assert orig == str(s) == str(s3) == str(s4)
def test_correlated_select(self):
s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2)
- class Vis(ClauseVisitor):
+ class Vis(CloningVisitor):
def visit_select(self, select):
select.append_whereclause(t1.c.col2==7)
- self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :col2_1")
-
+ self.assert_compile(Vis().traverse(s), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :col2_1")
+
+ def test_this_thing(self):
+ s = select([t1]).where(t1.c.col1=='foo').alias()
+ s2 = select([s.c.col1])
+
+ self.assert_compile(s2, "SELECT anon_1.col1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1")
+ t1a = t1.alias()
+ s2 = sql_util.ClauseAdapter(t1a).traverse(s2)
+ self.assert_compile(s2, "SELECT anon_1.col1 FROM (SELECT table1_1.col1 AS col1, table1_1.col2 AS col2, table1_1.col3 AS col3 FROM table1 AS table1_1 WHERE table1_1.col1 = :col1_1) AS anon_1")
+
+ def test_select_fromtwice(self):
+ t1a = t1.alias()
+
+ s = select([1], t1.c.col1==t1a.c.col1, from_obj=t1a).correlate(t1)
+ self.assert_compile(s, "SELECT 1 FROM table1 AS table1_1 WHERE table1.col1 = table1_1.col1")
+
+ s = CloningVisitor().traverse(s)
+ self.assert_compile(s, "SELECT 1 FROM table1 AS table1_1 WHERE table1.col1 = table1_1.col1")
+
+ s = select([t1]).where(t1.c.col1=='foo').alias()
+
+ s2 = select([1], t1.c.col1==s.c.col1, from_obj=s).correlate(t1)
+ self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1")
+ s2 = ReplacingCloningVisitor().traverse(s2)
+ self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1")
+
class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
def setUpAll(self):
global t1, t2
assert t1alias in s._froms
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
- s = vis.traverse(s, clone=True)
+ s = vis.traverse(s)
+
assert t2alias not in s._froms # not present because it's been cloned
+
assert t1alias in s._froms # present because the adapter placed it there
+
# correlate list on "s" needs to take into account the full _cloned_set for each element in _froms when correlating
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
s = select(['*'], from_obj=[t1alias, t2alias]).correlate(t2alias).as_scalar()
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
- s = vis.traverse(s, clone=True)
+ s = vis.traverse(s)
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
- s = ClauseVisitor().traverse(s, clone=True)
+ s = CloningVisitor().traverse(s)
self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
s = select(['*']).where(t1.c.col1==t2.c.col1).as_scalar()
self.assert_compile(select([t1.c.col1, s]), "SELECT table1.col1, (SELECT * FROM table2 WHERE table1.col1 = table2.col1) AS anon_1 FROM table1")
vis = sql_util.ClauseAdapter(t1alias)
- s = vis.traverse(s, clone=True)
+ s = vis.traverse(s)
self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
- s = ClauseVisitor().traverse(s, clone=True)
+ s = CloningVisitor().traverse(s)
self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
s = select(['*']).where(t1.c.col1==t2.c.col1).correlate(t1).as_scalar()
self.assert_compile(select([t1.c.col1, s]), "SELECT table1.col1, (SELECT * FROM table2 WHERE table1.col1 = table2.col1) AS anon_1 FROM table1")
vis = sql_util.ClauseAdapter(t1alias)
- s = vis.traverse(s, clone=True)
+ s = vis.traverse(s)
self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
- s = ClauseVisitor().traverse(s, clone=True)
+ s = CloningVisitor().traverse(s)
self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
-
+
+ @testing.fails_on_everything_except()
+ def test_joins_dont_adapt(self):
+ # adapting to a join, i.e. ClauseAdapter(t1.join(t2)), doesn't make much sense.
+ # ClauseAdapter doesn't make any changes if it's against a straight join.
+ users = table('users', column('id'))
+ addresses = table('addresses', column('id'), column('user_id'))
+
+ ualias = users.alias()
+
+ s = select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users) #.as_scalar().label(None)
+ s = sql_util.ClauseAdapter(ualias).traverse(s)
+
+ j1 = addresses.join(ualias, addresses.c.user_id==ualias.c.id)
+
+ self.assert_compile(sql_util.ClauseAdapter(j1).traverse(s), "SELECT count(addresses.id) AS count_1 FROM addresses WHERE users_1.id = addresses.user_id")
def test_table_to_alias(self):
t1alias = t1.alias('t1alias')
vis = sql_util.ClauseAdapter(t1alias)
- ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
- assert ff._get_from_objects() == [t1alias]
+ ff = vis.traverse(func.count(t1.c.col1).label('foo'))
+ assert list(_from_objects(ff)) == [t1alias]
- self.assert_compile(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], from_obj=[t1])), "SELECT * FROM table1 AS t1alias")
+ self.assert_compile(select(['*'], t1.c.col1==t2.c.col2), "SELECT * FROM table1, table2 WHERE table1.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2)), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2])), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1)), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2)), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
s = select(['*'], from_obj=[t1]).alias('foo')
self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo")
- self.assert_compile(vis.traverse(s.select(), clone=True), "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo")
+ self.assert_compile(vis.traverse(s.select()), "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo")
self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo")
- ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
- self.assert_compile(ff, "count(t1alias.col1) AS foo")
- assert ff._get_from_objects() == [t1alias]
+ ff = vis.traverse(func.count(t1.c.col1).label('foo'))
+ self.assert_compile(select([ff]), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias")
+ assert list(_from_objects(ff)) == [t1alias]
# TODO:
# self.assert_compile(vis.traverse(select([func.count(t1.c.col1).label('foo')]), clone=True), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias")
t2alias = t2.alias('t2alias')
vis.chain(sql_util.ClauseAdapter(t2alias))
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
- self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2)), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2])), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1)), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+ self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2)), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
def test_include_exclude(self):
m = MetaData()
"WHERE c.bid = anon_1.b_aid"
)
+class SpliceJoinsTest(TestBase, AssertsCompiledSQL):
+ def setUpAll(self):
+ global table1, table2, table3, table4
+ def _table(name):
+ return table(name, column("col1"), column("col2"), column("col3"))
+
+ table1, table2, table3, table4 = [_table(name) for name in ("table1", "table2", "table3", "table4")]
+
+ def test_splice(self):
+ (t1, t2, t3, t4) = (table1, table2, table1.alias(), table2.alias())
+
+ j = t1.join(t2, t1.c.col1==t2.c.col1).join(t3, t2.c.col1==t3.c.col1).join(t4, t4.c.col1==t1.c.col1)
+
+ s = select([t1]).where(t1.c.col2<5).alias()
+
+ self.assert_compile(sql_util.splice_joins(s, j),
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, "\
+ "table1.col3 AS col3 FROM table1 WHERE table1.col2 < :col2_1) AS anon_1 "\
+ "JOIN table2 ON anon_1.col1 = table2.col1 JOIN table1 AS table1_1 ON table2.col1 = table1_1.col1 "\
+ "JOIN table2 AS table2_1 ON table2_1.col1 = anon_1.col1")
+
+ def test_stop_on(self):
+ (t1, t2, t3) = (table1, table2, table3)
+
+ j1 = t1.join(t2, t1.c.col1==t2.c.col1)
+ j2 = j1.join(t3, t2.c.col1==t3.c.col1)
+
+ s = select([t1]).select_from(j1).alias()
+
+ self.assert_compile(sql_util.splice_joins(s, j2),
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 JOIN table2 "\
+ "ON table1.col1 = table2.col1) AS anon_1 JOIN table2 ON anon_1.col1 = table2.col1 JOIN table3 "\
+ "ON table2.col1 = table3.col1"
+ )
+
+ self.assert_compile(sql_util.splice_joins(s, j2, j1),
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 "\
+ "JOIN table2 ON table1.col1 = table2.col1) AS anon_1 JOIN table3 ON table2.col1 = table3.col1")
+
+ def test_splice_2(self):
+ t2a = table2.alias()
+ t3a = table3.alias()
+ j1 = table1.join(t2a, table1.c.col1==t2a.c.col1).join(t3a, t2a.c.col2==t3a.c.col2)
+
+ t2b = table4.alias()
+ j2 = table1.join(t2b, table1.c.col3==t2b.c.col3)
+
+ self.assert_compile(sql_util.splice_joins(table1, j1),
+ "table1 JOIN table2 AS table2_1 ON table1.col1 = table2_1.col1 "\
+ "JOIN table3 AS table3_1 ON table2_1.col2 = table3_1.col2")
+
+ self.assert_compile(sql_util.splice_joins(table1, j2), "table1 JOIN table4 AS table4_1 ON table1.col3 = table4_1.col3")
+
+ self.assert_compile(sql_util.splice_joins(sql_util.splice_joins(table1, j1), j2),
+ "table1 JOIN table2 AS table2_1 ON table1.col1 = table2_1.col1 "\
+ "JOIN table3 AS table3_1 ON table2_1.col2 = table3_1.col2 "\
+ "JOIN table4 AS table4_1 ON table1.col3 = table4_1.col3")
+
+
class SelectTest(TestBase, AssertsCompiledSQL):
"""tests the generative capability of Select"""
import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
from sqlalchemy.engine import default
from testlib import *
try:
print r['user_id']
assert False
- except exceptions.InvalidRequestError, e:
+ except exc.InvalidRequestError, e:
assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \
str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement."
def test_cant_execute_join(self):
try:
users.join(addresses).execute()
- except exceptions.ArgumentError, e:
+ except exc.ArgumentError, e:
assert str(e).startswith('Not an executable clause: ')
from sqlalchemy.sql import compiler
from testlib import *
-class QuoteTest(TestBase):
+class QuoteTest(TestBase, AssertsCompiledSQL):
def setUpAll(self):
# TODO: figure out which databases/which identifiers allow special characters to be used,
# such as: spaces, quote characters, punctuation characters, set up tests for those as
res2 = select([table2.c.d123, table2.c.u123, table2.c.MixedCase], use_labels=True).execute().fetchall()
print res2
assert(res2==[(1,2,3),(2,2,3),(4,3,2)])
+
+ def test_quote_flag(self):
+ metadata = MetaData()
+ t1 = Table('TableOne', metadata,
+ Column('ColumnOne', Integer), schema="FooBar")
+ self.assert_compile(t1.select(), '''SELECT "FooBar"."TableOne"."ColumnOne" FROM "FooBar"."TableOne"''')
+
+ metadata = MetaData()
+ t1 = Table('t1', metadata,
+ Column('col1', Integer, quote=True), quote=True, schema="foo", quote_schema=True)
+ self.assert_compile(t1.select(), '''SELECT "foo"."t1"."col1" FROM "foo"."t1"''')
+ metadata = MetaData()
+ t1 = Table('TableOne', metadata,
+ Column('ColumnOne', Integer, quote=False), quote=False, schema="FooBar", quote_schema=False)
+ self.assert_compile(t1.select(), '''SELECT FooBar.TableOne.ColumnOne FROM FooBar.TableOne''')
+
@testing.unsupported('oracle')
def testlabels(self):
"""test the quoting of labels.
table = Table("ImATable", metadata,
Column("col1", Integer))
x = select([table.c.col1.label("ImATable_col1")]).alias("SomeAlias")
- assert str(select([x.c.ImATable_col1])) == '''SELECT "SomeAlias"."ImATable_col1" \nFROM (SELECT "ImATable".col1 AS "ImATable_col1" \nFROM "ImATable") AS "SomeAlias"'''
+ self.assert_compile(select([x.c.ImATable_col1]),
+ '''SELECT "SomeAlias"."ImATable_col1" FROM (SELECT "ImATable".col1 AS "ImATable_col1" FROM "ImATable") AS "SomeAlias"''')
# note that 'foo' and 'FooCol' are literals already quoted
x = select([sql.literal_column("'foo'").label("somelabel")], from_obj=[table]).alias("AnAlias")
x = x.select()
- assert str(x) == '''SELECT "AnAlias".somelabel \nFROM (SELECT 'foo' AS somelabel \nFROM "ImATable") AS "AnAlias"'''
+ self.assert_compile(x,
+ '''SELECT "AnAlias".somelabel FROM (SELECT 'foo' AS somelabel FROM "ImATable") AS "AnAlias"''')
x = select([sql.literal_column("'FooCol'").label("SomeLabel")], from_obj=[table])
x = x.select()
- assert str(x) == '''SELECT "SomeLabel" \nFROM (SELECT 'FooCol' AS "SomeLabel" \nFROM "ImATable")'''
+ self.assert_compile(x,
+ '''SELECT "SomeLabel" FROM (SELECT 'FooCol' AS "SomeLabel" FROM "ImATable")''')
class PreparerTest(TestBase):
import testenv; testenv.configure_for_tests()
import datetime, re, operator
from sqlalchemy import *
-from sqlalchemy import exceptions, sql, util
+from sqlalchemy import exc, sql, util
from sqlalchemy.sql import table, column, compiler
from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
from testlib import *
t2 = table('t2', column('c'), column('d'))
s = select([t.c.a]).where(t.c.a==t2.c.d).as_scalar()
s2 = select([t, t2, s])
- self.assertRaises(exceptions.InvalidRequestError, str, s2)
+ self.assertRaises(exc.InvalidRequestError, str, s2)
# intentional again
s = s.correlate(t, t2)
try:
s = select([table1.c.myid, table1.c.name]).as_scalar()
assert False
- except exceptions.InvalidRequestError, err:
+ except exc.InvalidRequestError, err:
assert str(err) == "Scalar select can only be created from a Select object that has exactly one column expression.", str(err)
try:
# generic function which will look at the type of expression
func.coalesce(select([table1.c.myid]))
assert False
- except exceptions.InvalidRequestError, err:
+ except exc.InvalidRequestError, err:
assert str(err) == "Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.", str(err)
s = select([table1.c.myid], scalar=True, correlate=False)
s = select([table1.c.myid]).as_scalar()
try:
s.c.foo
- except exceptions.InvalidRequestError, err:
+ except exc.InvalidRequestError, err:
assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
try:
s.columns.foo
- except exceptions.InvalidRequestError, err:
+ except exc.InvalidRequestError, err:
assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
zips = table('zips',
self.assert_compile(
select(
- [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid)
- ]),
+ [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid)]
+ ),
"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid"
)
def test_compound_selects(self):
try:
union(table3.select(), table1.select())
- except exceptions.ArgumentError, err:
+ except exc.ArgumentError, err:
assert str(err) == "All selectables passed to CompoundSelect must have identical numbers of columns; select #1 has 2 columns, select #2 has 3"
x = union(
# check that conflicts with "unique" params are caught
s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('myid_1')))
- self.assertRaisesMessage(exceptions.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+ self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('myid_1')))
- self.assertRaisesMessage(exceptions.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+ self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
self.assert_compile(select([table1], table1.c.myid.in_([])),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
- @testing.uses_deprecated('passing in_')
- def test_in_deprecated_api(self):
- self.assert_compile(select([table1], table1.c.myid.in_('abc')),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)")
-
- self.assert_compile(select([table1], table1.c.myid.in_(1)),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)")
-
- self.assert_compile(select([table1], table1.c.myid.in_(1,2)),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :myid_2)")
-
- self.assert_compile(select([table1], table1.c.myid.in_()),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
-
def test_cast(self):
tbl = table('casttest',
column('id', Integer),
from sqlalchemy import *
from testlib import *
from sqlalchemy.sql import util as sql_util
-from sqlalchemy import exceptions
+from sqlalchemy import exc
metadata = MetaData()
table = Table('table1', metadata,
print str(j)
self.assert_(criterion.compare(j.onclause))
- def testcolumnlabels(self):
+ def test_column_labels(self):
a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')])
print str(a)
print [c for c in a.columns]
criterion = a.c.acol1 == table2.c.col2
print str(j)
self.assert_(criterion.compare(j.onclause))
-
+
def test_labeled_select_corresponding(self):
l1 = select([func.max(table.c.col1)]).label('foo')
s = select([l1])
assert s.corresponding_column(l1).name == s.c.foo
-
+
s = select([table.c.col1, l1])
assert s.corresponding_column(l1).name == s.c.foo
print str(j.onclause)
self.assert_(criterion.compare(j.onclause))
- def testtablejoinedtoselectoftable(self):
+ def test_table_joined_to_select_of_table(self):
metadata = MetaData()
a = Table('a', metadata,
Column('id', Integer, primary_key=True))
s = select([t2, t3], use_labels=True)
- self.assertRaises(exceptions.NoReferencedTableError, s.join, t1)
+ self.assertRaises(exc.NoReferencedTableError, s.join, t1)
class PrimaryKeyTest(TestBase, AssertsExecutionResults):
def test_join_pk_collapse_implicit(self):
import testenv; testenv.configure_for_tests()
import datetime, os, pickleable, re
from sqlalchemy import *
-from sqlalchemy import exceptions, types, util
+from sqlalchemy import exc, types, util
from sqlalchemy.sql import operators
import sqlalchemy.engine.url as url
from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
assert isinstance(dialect_type, mssql.MSNVarchar)
assert dialect_type.get_col_spec() == 'NVARCHAR(10)'
- def testoracletext(self):
- dialect = oracle.OracleDialect()
- class MyDecoratedType(types.TypeDecorator):
- impl = String
- def copy(self):
- return MyDecoratedType()
-
- col = Column('', MyDecoratedType)
- dialect_type = col.type.dialect_impl(dialect)
- assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl)
-
def testoracletimestamp(self):
dialect = oracle.OracleDialect()
firebird_dialect = firebird.FBDialect()
for dialect, start, test in [
- (oracle_dialect, String(), oracle.OracleText),
+ (oracle_dialect, String(), oracle.OracleString),
(oracle_dialect, VARCHAR(), oracle.OracleString),
(oracle_dialect, String(50), oracle.OracleString),
- (oracle_dialect, Unicode(), oracle.OracleText),
+ (oracle_dialect, Unicode(), oracle.OracleString),
(oracle_dialect, UnicodeText(), oracle.OracleText),
(oracle_dialect, NCHAR(), oracle.OracleString),
(oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw),
- (mysql_dialect, String(), mysql.MSText),
+ (mysql_dialect, String(), mysql.MSString),
(mysql_dialect, VARCHAR(), mysql.MSString),
(mysql_dialect, String(50), mysql.MSString),
- (mysql_dialect, Unicode(), mysql.MSText),
+ (mysql_dialect, Unicode(), mysql.MSString),
(mysql_dialect, UnicodeText(), mysql.MSText),
(mysql_dialect, NCHAR(), mysql.MSNChar),
- (postgres_dialect, String(), postgres.PGText),
+ (postgres_dialect, String(), postgres.PGString),
(postgres_dialect, VARCHAR(), postgres.PGString),
(postgres_dialect, String(50), postgres.PGString),
- (postgres_dialect, Unicode(), postgres.PGText),
+ (postgres_dialect, Unicode(), postgres.PGString),
(postgres_dialect, UnicodeText(), postgres.PGText),
(postgres_dialect, NCHAR(), postgres.PGString),
- (firebird_dialect, String(), firebird.FBText),
+ (firebird_dialect, String(), firebird.FBString),
(firebird_dialect, VARCHAR(), firebird.FBString),
(firebird_dialect, String(50), firebird.FBString),
- (firebird_dialect, Unicode(), firebird.FBText),
+ (firebird_dialect, Unicode(), firebird.FBString),
(firebird_dialect, UnicodeText(), firebird.FBText),
(firebird_dialect, NCHAR(), firebird.FBString),
]:
def testprocessing(self):
global users
- users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12)
- users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15)
- users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9)
+ users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12)
+ users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15)
+ users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9)
l = users.select().execute().fetchall()
for assertstr, assertint, assertint2, row in zip(
l
):
- for col in row[1:8]:
+ for col in row[1:7]:
self.assertEquals(col, assertstr)
- self.assertEquals(row[8], assertint)
- self.assertEquals(row[9], assertint2)
- for col in (row[4], row[5], row[7]):
+ self.assertEquals(row[7], assertint)
+ self.assertEquals(row[8], assertint2)
+ for col in (row[3], row[4], row[6]):
assert isinstance(col, unicode)
def setUpAll(self):
# decorated type with an argument, so it's a String
Column('goofy2', MyDecoratedType(50), nullable = False),
- # decorated type without an argument, it will adapt_args to TEXT
- Column('goofy3', MyDecoratedType, nullable = False),
-
- Column('goofy4', MyUnicodeType, nullable = False),
- Column('goofy5', LegacyUnicodeType, nullable = False),
+ Column('goofy4', MyUnicodeType(50), nullable = False),
+ Column('goofy5', LegacyUnicodeType(50), nullable = False),
Column('goofy6', LegacyType, nullable = False),
- Column('goofy7', MyNewUnicodeType, nullable = False),
+ Column('goofy7', MyNewUnicodeType(50), nullable = False),
Column('goofy8', MyNewIntType, nullable = False),
Column('goofy9', MyNewIntSubClass, nullable = False),
try:
unicode_table.insert().execute(unicode_varchar='not unicode')
assert False
- except exceptions.SAWarning, e:
+ except exc.SAWarning, e:
assert str(e) == "Unicode type received non-unicode bind param value 'not unicode'", str(e)
unicode_engine = engines.utf8_engine(options={'convert_unicode':True,
try:
unicode_engine.execute(unicode_table.insert(), plain_varchar='im not unicode')
assert False
- except exceptions.InvalidRequestError, e:
+ except exc.InvalidRequestError, e:
assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'"
@testing.emits_warning('.*non-unicode bind')
t.drop(checkfirst=True)
class StringTest(TestBase, AssertsExecutionResults):
- def test_nolen_string_deprecated(self):
+
+
+ def test_nolength_string(self):
+ # this tests what happens with String DDL with no length.
+ # seems like we need to decide among "VARCHAR" (sqlite, postgres) and "TEXT" (mysql);
+ # i.e. there's some inconsistency here.
+
metadata = MetaData(testing.db)
foo = Table('foo', metadata,
Column('one', String))
-
- # no warning
- select([func.count("*")], bind=testing.db).execute()
-
- try:
- # warning during CREATE
- foo.create()
- assert False
- except exceptions.SADeprecationWarning, e:
- assert "Using String type with no length" in str(e)
- assert re.search(r'\bone\b', str(e))
-
- bar = Table('bar', metadata, Column('one', String(40)))
-
- try:
- # no warning
- bar.create()
-
- # no warning for non-lengthed string
- select([func.count("*")], from_obj=bar).execute()
- finally:
- bar.drop()
-
+
+ foo.create()
+ foo.drop()
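+
+ # illustrative only, not asserted here: with no length, sqlite and
+ # postgres render the column roughly as "one VARCHAR" while mysql
+ # renders "one TEXT" -- the inconsistency noted above.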
+
def _missing_decimal():
"""Python implementation supports decimals"""
try:
Load after sqlalchemy imports to use instrumented stand-ins like Table.
"""
+import sys
import testlib.config
from testlib.schema import Table, Column
from testlib.orm import mapper
import testlib.testing as testing
-from testlib.testing import rowset
-from testlib.testing import TestBase, AssertsExecutionResults, ORMTest, AssertsCompiledSQL, ComparesTables
+from testlib.testing import \
+ AssertsCompiledSQL, \
+ AssertsExecutionResults, \
+ ComparesTables, \
+ ORMTest, \
+ TestBase, \
+ rowset
import testlib.profiling as profiling
import testlib.engines as engines
+import testlib.requires as requires
from testlib.compat import set, frozenset, sorted, _function_named
'mapper',
'Table', 'Column',
'rowset',
- 'TestBase', 'AssertsExecutionResults', 'ORMTest', 'AssertsCompiledSQL', 'ComparesTables',
+ 'TestBase', 'AssertsExecutionResults', 'ORMTest',
+ 'AssertsCompiledSQL', 'ComparesTables',
'profiling', 'engines',
'set', 'frozenset', 'sorted', '_function_named')
+
+
+testing.requires = requires
+
+sys.modules['testlib.sa'] = sa = testing.CompositeModule(
+ 'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule(
+ 'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm'))
+sys.modules['testlib.sa.orm'] = sa.orm
-import itertools, new, sys, warnings
+import new
-__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque'
+__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque', 'reversed'
try:
set = set
l.sort()
return l
+try:
+ reversed = reversed
+except NameError:
+ def reversed(seq):
+ i = len(seq) - 1
+ while i >= 0:
+ yield seq[i]
+ i -= 1
+ raise StopIteration()
+
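+ # A minimal sanity sketch for the fallback above (illustrative, not part
+ # of the suite): on Python 2.3, where the builtin is absent, the
+ # generator should match the 2.4 builtin's behavior:
+ #
+ #     assert list(reversed([1, 2, 3])) == [3, 2, 1]
+ #     assert list(reversed('ab')) == ['b', 'a']
+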
try:
from collections import deque
except ImportError:
def popleft(self):
return self.pop(0)
def extendleft(self, iterable):
- items = list(iterable)
- items.reverse()
- for x in items:
+ for x in reversed(list(iterable)):
self.insert(0, x)
def _function_named(fn, newname):
import sys, types, weakref
from testlib import config
-from testlib.compat import *
+from testlib.compat import set, _function_named, deque
class ConnectionKiller(object):
"""
import sys
-from StringIO import StringIO
-from tokenize import *
+from tokenize import generate_tokens, INDENT, DEDENT, NAME, OP, NL, NEWLINE, \
+ NUMBER, STRING, COMMENT
__all__ = ['py23_decorators', 'py23']
-# can't be imported until the path is setup; be sure to configure
-# first if covering.
-from sqlalchemy import *
-from sqlalchemy import util
-from testlib import *
-
-__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest', 'Dingaling', 'item_keywords',
- 'dingalings', 'User', 'items', 'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users',
+from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey
+from testlib.sa.orm import attributes
+from testlib import ORMTest
+from testlib.compat import set
+
+
+__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest',
+ 'Dingaling', 'item_keywords', 'dingalings', 'User', 'items',
+ 'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users',
'order_items', 'Item', 'Order', 'fixtures']
-
-_recursion_stack = util.Set()
+
+
+_recursion_stack = set()
class Base(object):
def __init__(self, **kwargs):
for k in kwargs:
_recursion_stack.add(self)
try:
# pick the entity thats not SA persisted as the source
+ try:
+ state = attributes.instance_state(self)
+ key = state.key
+ except (KeyError, AttributeError):
+ key = None
if other is None:
a = self
b = other
- elif hasattr(self, '_instance_key'):
+ elif key is not None:
a = other
b = self
else:
battr = getattr(b, attr)
except AttributeError:
#print "b class does not have attribute named '%s'" % attr
+ #raise
return False
-
+
if list(value) == list(battr):
continue
else:
users = Table('users', metadata,
Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False))
+ Column('name', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
+ )
orders = Table('orders', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', None, ForeignKey('users.id')),
Column('address_id', None, ForeignKey('addresses.id')),
Column('description', String(30)),
- Column('isopen', Integer)
+ Column('isopen', Integer),
+ test_needs_acid=True,
+ test_needs_fk=True
)
addresses = Table('addresses', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', None, ForeignKey('users.id')),
- Column('email_address', String(50), nullable=False))
+ Column('email_address', String(50), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True)
dingalings = Table("dingalings", metadata,
Column('id', Integer, primary_key=True),
Column('address_id', None, ForeignKey('addresses.id')),
- Column('data', String(30))
+ Column('data', String(30)),
+ test_needs_acid=True,
+ test_needs_fk=True
)
items = Table('items', metadata,
Column('id', Integer, primary_key=True),
- Column('description', String(30), nullable=False)
+ Column('description', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
)
order_items = Table('order_items', metadata,
Column('item_id', None, ForeignKey('items.id')),
- Column('order_id', None, ForeignKey('orders.id')))
+ Column('order_id', None, ForeignKey('orders.id')),
+ test_needs_acid=True,
+ test_needs_fk=True)
item_keywords = Table('item_keywords', metadata,
Column('item_id', None, ForeignKey('items.id')),
- Column('keyword_id', None, ForeignKey('keywords.id')))
+ Column('keyword_id', None, ForeignKey('keywords.id')),
+ test_needs_acid=True,
+ test_needs_fk=True)
keywords = Table('keywords', metadata,
Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False)
+ Column('name', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
)
def install_fixture_data():
class FixtureTest(ORMTest):
refresh_data = False
-
+ only_tables = False
+
def setUpAll(self):
super(FixtureTest, self).setUpAll()
- if self.keep_data:
+ if not self.only_tables and self.keep_data:
install_fixture_data()
def setUp(self):
- if self.refresh_data:
+ if not self.only_tables and self.refresh_data:
install_fixture_data()
def define_tables(self, meta):
"""Profiling support for unit and performance tests."""
import os, sys
-from testlib.config import parser, post_configure
-from testlib.compat import *
+from testlib.compat import set, _function_named
import testlib.config
__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
configuration and command-line options.
"""
- import time, hotshot, hotshot.stats
-
# manual or automatic namespacing by module would remove conflict issues
if target is None:
target = 'anonymous_target'
--- /dev/null
+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+"""
+from testlib import testing
+
+def savepoints(fn):
+ """Target database must support savepoints."""
+ return (testing.unsupported(
+ 'access',
+ 'mssql',
+ 'sqlite',
+ 'sybase',
+ )
+ (testing.exclude('mysql', '<', (5, 0, 3))
+ (fn)))
+
+def two_phase_transactions(fn):
+ """Target database must support two-phase transactions."""
+ return (testing.unsupported(
+ 'access',
+ 'firebird',
+ 'maxdb',
+ 'mssql',
+ 'oracle',
+ 'sqlite',
+ 'sybase',
+ )
+ (testing.exclude('mysql', '<', (5, 0, 3))
+ (fn)))
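+
+# A hedged usage sketch (hypothetical test method, not part of this patch):
+# each policy decorator wraps a plain test function, so it stacks like any
+# other testing decorator:
+#
+#     from testlib import requires
+#
+#     class SavepointTest(TestBase):
+#         @requires.savepoints
+#         def test_nested_transaction(self):
+#             ...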
from testlib import testing
-import itertools
+
schema = None
__all__ = 'Table', 'Column',
# can't be imported until the path is setup; be sure to configure
# first if covering.
-from sqlalchemy import *
+
from testlib import testing
-from testlib.schema import Table, Column
+from testlib.sa import MetaData, Table, Column, Integer, String, Sequence, \
+ ForeignKey, VARCHAR, INT
# these are older test fixtures, used primarily by test/orm/mapper.py and
# monkeypatches unittest.TestLoader.suiteClass at import time
-import itertools, os, operator, re, sys, unittest, warnings
+import itertools
+import operator
+import re
+import sys
+import types
+import unittest
+import warnings
from cStringIO import StringIO
+
import testlib.config as config
-from testlib.compat import *
+from testlib.compat import set, _function_named, reversed
-sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None
-sa_exceptions = None
+# Delayed imports
+MetaData = None
+Session = None
+clear_mappers = None
+sa_exc = None
+schema = None
+sqltypes = None
+util = None
-__all__ = ('TestBase', 'AssertsExecutionResults', 'ComparesTables', 'ORMTest', 'AssertsCompiledSQL')
_ops = { '<': operator.lt,
'>': operator.gt,
# sugar ('testing.db'); set here by config() at runtime
db = None
+# more sugar, installed by __init__
+requires = None
+
def fails_if(callable_):
"""Mark a test as expected to fail if callable_ returns True.
# - update: jython looks ok, it uses cpython's module
def decorate(fn):
def safe(*args, **kw):
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
if not messages:
filters = [dict(action='ignore',
- category=sa_exceptions.SAWarning)]
+ category=sa_exc.SAWarning)]
else:
filters = [dict(action='ignore',
message=message,
- category=sa_exceptions.SAWarning)
+ category=sa_exc.SAWarning)
for message in messages ]
for f in filters:
warnings.filterwarnings(**f)
def decorate(fn):
def safe(*args, **kw):
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
if not messages:
filters = [dict(action='ignore',
- category=sa_exceptions.SADeprecationWarning)]
+ category=sa_exc.SADeprecationWarning)]
else:
filters = [dict(action='ignore',
message=message,
- category=sa_exceptions.SADeprecationWarning)
+ category=sa_exc.SADeprecationWarning)
for message in
[ (m.startswith('//') and
('Call to deprecated function ' + m[2:]) or m)
def resetwarnings():
"""Reset warning behavior to testing defaults."""
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
warnings.resetwarnings()
- warnings.filterwarnings('error', category=sa_exceptions.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exceptions.SAWarning)
+ warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings('error', category=sa_exc.SAWarning)
if sys.version_info < (2, 4):
warnings.filterwarnings('ignore', category=FutureWarning)
return set([tuple(row) for row in results])
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+
class TestData(object):
"""Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
can be tracked."""
def __init__(self, ctx):
- global sql
- if sql is None:
- from sqlalchemy import sql
-
self.__dict__['ctx'] = ctx
def __getattr__(self, key):
return getattr(self.ctx, key)
query = self.convert_statement(query)
equivalent = ( (statement == query)
- or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
+ or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
) \
and \
( (params is None) or (params == parameters)
for (k, v) in p.items()])
for p in parameters]
)
- testdata.unittest.assert_(equivalent,
+ testdata.unittest.assert_(equivalent,
"Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
testdata.sql_count += 1
self.ctx.post_execution()
query = re.sub(r':([\w_]+)', repl, query)
return query
+
+def _import_by_name(name):
+ submodule = name.split('.')[-1]
+ return __import__(name, globals(), locals(), [submodule])
+
+class CompositeModule(types.ModuleType):
+ """Merged attribute access for multiple modules."""
+
+ # break the habit of star imports: export nothing by default
+ __all__ = ()
+
+ def __init__(self, name, *modules, **overrides):
+ """Construct a new lazy composite of modules.
+
+ Modules may be string names or module-like instances. Individual
+ attribute overrides may be specified as keyword arguments for
+ convenience.
+
+ The constructed module will resolve attribute access in reverse order:
+ overrides, then each member of reversed(modules). Modules specified
+ by name will be loaded lazily when encountered in attribute
+ resolution.
+
+ """
+ types.ModuleType.__init__(self, name)
+ self.__modules = list(reversed(modules))
+ for key, value in overrides.iteritems():
+ setattr(self, key, value)
+
+ def __getattr__(self, key):
+ for idx, mod in enumerate(self.__modules):
+ if isinstance(mod, basestring):
+ self.__modules[idx] = mod = _import_by_name(mod)
+ if hasattr(mod, key):
+ return getattr(mod, key)
+ raise AttributeError(key)
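+
+# Illustrative resolution-order sketch (module names below are examples
+# only): keyword overrides become plain attributes and win outright; other
+# lookups walk reversed(modules), so later-listed modules shadow earlier
+# ones, and string names are imported lazily on first use:
+#
+#     combo = CompositeModule('combo', 'os.path', 'os', sep='|')
+#     combo.sep      # '|'  -- keyword override
+#     combo.getcwd   # resolved on 'os' (listed last, searched first)
+#     combo.join     # not on 'os', falls back to 'os.path'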
+
+
class TestBase(unittest.TestCase):
# A sequence of dialect names to exclude from the test class.
__unsupported_on__ = ()
def shortDescription(self):
"""overridden to not return docstrings"""
return None
-
+
def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
assert False, "Callable did not raise expected exception"
except except_cls, e:
assert re.search(msg, str(e)), "Exception message did not match: '%s'" % str(e)
-
+
if not hasattr(unittest.TestCase, 'assertTrue'):
assertTrue = unittest.TestCase.failUnless
if not hasattr(unittest.TestCase, 'assertFalse'):
set(type(c.type).__mro__).difference(base_mro)
)
) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
-
+
if isinstance(c.type, sqltypes.String):
self.assertEquals(c.type.length, reflected_c.type.length)
elif not c.primary_key or not against('postgres'):
print repr(c)
assert reflected_c.default is None, reflected_c.default
-
+
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name]
-
+
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
result = list(result)
print repr(result)
self.assert_list(result, class_, objects)
-
+
def assert_list(self, result, class_, list):
self.assert_(len(result) == len(list),
"result list is not the same size as test list, " +
def define_tables(self, _otest_metadata):
raise NotImplementedError()
-
+
def setup_mappers(self):
pass
-
+
def insert_data(self):
pass