From: Mike Bayer Date: Fri, 9 May 2008 14:07:28 +0000 (+0000) Subject: back-merged current 0.4 trunk into rel_0_4 branch, which will become the 0.4 maintena... X-Git-Tag: rel_0_4_6~2 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=f865f0a88da54707a9b33dfb121025dfab686009;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git back-merged current 0.4 trunk into rel_0_4 branch, which will become the 0.4 maintenance branch --- diff --git a/CHANGES b/CHANGES index ec8d8fcce1..35d53ab618 100644 --- a/CHANGES +++ b/CHANGES @@ -1,241 +1,2256 @@ -0.4.0 +======= +CHANGES +======= + +0.4.6 +===== +- orm + - A fix to the recent relation() refactoring which fixes + exotic viewonly relations which join between local and + remote table multiple times, with a common column shared + between the joins. + + - Also re-established viewonly relation() configurations + that join across multiple tables. + + - contains_eager(), the hot function of the week, suppresses + the eager loader's own generation of the LEFT OUTER JOIN, + so that it is reasonable to use any Query, not just those + which use from_statement(). + + - Added an experimental relation() flag to help with + primaryjoins across functions, etc., + _local_remote_pairs=[tuples]. This complements a complex + primaryjoin condition allowing you to provide the + individual column pairs which comprise the relation's + local and remote sides. Also improved lazy load SQL + generation to handle placing bind params inside of + functions and other expressions. (partial progress + towards [ticket:610]) + + - Fixed "concatenate tuple" bug which could occur with + Query.order_by() if clause adaption had taken place. + [ticket:1027] + + - Removed an ancient assertion that mapped selectables + require "alias names" - the mapper creates its own alias + now if none is present. Though in this case you need to + use the class, not the mapped selectable, as the source of + column attributes - so a warning is still issued. + + - Fixes to the "exists" function involving inheritance + (any(), has(), ~contains()); the full target join will be + rendered into the EXISTS clause for relations that link to + subclasses. + + - Restored usage of append_result() extension method for + primary query rows, when the extension is present and only + a single- entity result is being returned. + + - Fixed Class.collection==None for m2m relationships + [ticket:4213] + + - Refined mapper._save_obj() which was unnecessarily calling + __ne__() on scalar values during flush [ticket:1015] + + - Added a feature to eager loading whereby subqueries set as + column_property() with explicit label names (which is not + necessary, btw) will have the label anonymized when the + instance is part of the eager join, to prevent conflicts + with a subquery or column of the same name on the parent + object. [ticket:1019] + + - Same as [ticket:1019] but repaired the non-labeled use + case [ticket:1022] + + - Adjusted class-member inspection during attribute and + collection instrumentation that could be problematic when + integrating with other frameworks. + + - Fixed duplicate append event emission on repeated + instrumented set.add() operations. + + - set-based collections |=, -=, ^= and &= are stricter about + their operands and only operate on sets, frozensets or + subclasses of the collection type. Previously, they would + accept any duck-typed set. + + - added an example dynamic_dict/dynamic_dict.py, illustrating + a simple way to place dictionary behavior on top of + a dynamic_loader. 
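    A rough sketch of the contains_eager() usage described above (the
    User/Address mapping and the 'addresses' relation() are hypothetical,
    not part of this release):

        from sqlalchemy.orm import contains_eager

        # the Query supplies its own join; contains_eager() instructs the
        # eager loader to populate User.addresses from those joined columns
        # instead of emitting an additional LEFT OUTER JOIN of its own
        users = session.query(User).\
            outerjoin('addresses').\
            options(contains_eager('addresses')).\
            all()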
+ +- sql + - Added COLLATE support via the .collate(<collation>) + expression operator and collate(<expr>, <collation>) sql + function. + + - Fixed bug with union() when applied to non-Table connected + select statements + + - Improved behavior of text() expressions when used as FROM + clauses, such as select().select_from(text("sometext")) + [ticket:1014] + + - Column.copy() respects the value of "autoincrement", fixes + usage with Migrate [ticket:1021] + +- engines + - Pool listeners can now be provided as a dictionary of + callables or a (possibly partial) duck-type of + PoolListener, your choice. + + - Added "reset_on_return" option to Pool which will disable + the database state cleanup step (e.g. issuing a + rollback()) when connections are returned to the pool. + +- extensions + - set-based association proxies |=, -=, ^= and &= are + stricter about their operands and only operate on sets, + frozensets or other association proxies. Previously, they + would accept any duck-typed set. + +- declarative extension + - Joined table inheritance mappers use a slightly relaxed + function to create the "inherit condition" to the parent + table, so that other foreign keys to not-yet-declared + Table objects don't trigger an error. + + - Fixed re-entrant mapper compile hang when a declared + attribute is used within ForeignKey, + i.e. ForeignKey(MyOtherClass.someattribute) + +- mssql + - Added "odbc_autotranslate" parameter to engine / dburi + parameters. Any given string will be passed through to the + ODBC connection string as: + + "AutoTranslate=%s" % odbc_autotranslate + + [ticket:1005] + + - Added "odbc_options" parameter to engine / dburi + parameters. The given string is simply appended to the + SQLAlchemy-generated odbc connection string. + + This should obviate the need of adding a myriad of ODBC + options in the future. + +- firebird + - Handle the "SUBSTRING(:string FROM :start FOR :length)" + builtin. + + +0.4.5 +===== +- orm + - A small change in behavior to session.merge() - existing + objects are checked for based on primary key attributes, not + necessarily _instance_key. So the widely requested + capability, that: + + x = MyObject(id=1) + x = sess.merge(x) + + will in fact load MyObject with id #1 from the database if + present, is now available. merge() still copies the state + of the given object to the persistent one, so an example + like the above would typically have copied "None" from all + attributes of "x" onto the persistent copy. These can be + reverted using session.expire(x). + + - Also fixed behavior in merge() whereby collection elements + present on the destination but not the merged collection + were not being removed from the destination. + + - Added a more aggressive check for "uncompiled mappers", + helps particularly with declarative layer [ticket:995] + + - The methodology behind "primaryjoin"/"secondaryjoin" has + been refactored. Behavior should be slightly more + intelligent, primarily in terms of error messages which + have been pared down to be more readable. In a slight + number of scenarios it can better resolve the correct + foreign key than before. + + - Added comparable_property(), adds query Comparator + behavior to regular, unmanaged Python properties + + - The functionality of query.with_polymorphic() has been + added to mapper() as a configuration option.
+ + It's set via several forms: + + with_polymorphic='*' + with_polymorphic=[mappers] + with_polymorphic=('*', selectable) + with_polymorphic=([mappers], selectable) + + This controls the default polymorphic loading strategy for + inherited mappers. When a selectable is not given, outer + joins are created for all joined-table inheriting mappers + requested. Note that the auto-create of joins is not + compatible with concrete table inheritance. + + The existing select_table flag on mapper() is now + deprecated and is synonymous with: + + with_polymorphic('*', select_table). + + Note that the underlying "guts" of select_table have been + completely removed and replaced with the newer, more + flexible approach. + + The new approach also automatically allows eager loads to + work for subclasses, if they are present, for example + + sess.query(Company).options( + eagerload_all( + [Company.employees.of_type(Engineer), 'machines'] + )) + + to load Company objects, their employees, and the + 'machines' collection of employees who happen to be + Engineers. A "with_polymorphic" Query option should be + introduced soon as well which would allow per-Query + control of with_polymorphic() on relations. + + - Added two "experimental" features to Query, "experimental" + in that their specific name/behavior is not carved in + stone just yet: _values() and _from_self(). We'd like + feedback on these. + + - _values(*columns) is given a list of column expressions, + and returns a new Query that only returns those + columns. When evaluated, the return value is a list of + tuples just like when using add_column() or + add_entity(), the only difference is that "entity zero", + i.e. the mapped class, is not included in the + results. This means it finally makes sense to use + group_by() and having() on Query, which have been + sitting around uselessly until now. + + A future change to this method may include that its + ability to join, filter and allow other options not + related to a "resultset" are removed, so the feedback + we're looking for is how people want to use + _values()...i.e. at the very end, or do people prefer to + continue generating after it's called. + + - _from_self() compiles the SELECT statement for the Query + (minus any eager loaders), and returns a new Query that + selects from that SELECT. So basically you can query + from a Query without needing to extract the SELECT + statement manually. This gives meaning to operations + like query[3:5]._from_self().filter(some + criterion). There's not much controversial here except + that you can quickly create highly nested queries that + are less efficient, and we want feedback on the naming + choice. + + - query.order_by() and query.group_by() will accept multiple + arguments using *args (like select() already does). + + - Added some convenience descriptors to Query: + query.statement returns the full SELECT construct, + query.whereclause returns just the WHERE part of the + SELECT construct. + + - Fixed/covered case when using a False/0 value as a + polymorphic discriminator. + + - Fixed bug which was preventing synonym() attributes from + being used with inheritance + + - Fixed SQL function truncation of trailing underscores + [ticket:996] + + - When attributes are expired on a pending instance, an + error will not be raised when the "refresh" action is + triggered and no result is found. 
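    A quick sketch of the experimental _values() and _from_self() methods
    described above (the Order mapping and its columns are hypothetical):

        from sqlalchemy import func

        # _values() drops "entity zero" (the Order entity itself) from the
        # result tuples, which is what makes group_by() useful here
        q = session.query(Order).\
            group_by(Order.region).\
            _values(Order.region, func.count(Order.id))
        totals = list(q)    # [(region, count), ...]

        # _from_self() selects from the query's own SELECT statement
        q = session.query(Order).order_by(Order.id)[3:5]._from_self()
        q = q.filter(Order.total > 100)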
+ + - Session.execute can now find binds from metadata + + - Adjusted the definition of "self-referential" to be any + two mappers with a common parent (this affects whether or + not aliased=True is required when joining with Query). + + - Made some fixes to the "from_joinpoint" argument to + query.join() so that if the previous join was aliased and + this one isn't, the join still happens successfully. + + - Assorted "cascade deletes" fixes: + - Fixed "cascade delete" operation of dynamic relations, + which had only been implemented for foreign-key + nulling behavior in 0.4.2 and not actual cascading + deletes [ticket:895] + + - Delete cascade without delete-orphan cascade on a + many-to-one will not delete orphans which were + disconnected from the parent before session.delete() + is called on the parent (one-to-many already had + this). + + - Delete cascade with delete-orphan will delete orphans + whether or not they remain attached to their also-deleted + parent. + + - delete-orphan cascade is properly detected on + relations that are present on superclasses when using + inheritance. + + - Fixed order_by calculation in Query to properly alias + mapper-config'ed order_by when using select_from() + + - Refactored the diffing logic that kicks in when replacing + one collection with another into collections.bulk_replace, + useful to anyone building multi-level collections. + + - Cascade traversal algorithm converted from recursive to + iterative to support deep object graphs. + +- sql + - Schema-qualified tables now will place the schemaname + ahead of the tablename in all column expressions as well + as when generating column labels. This prevents cross- + schema name collisions in all cases [ticket:999] + + - Can now allow selects which correlate all FROM clauses and + have no FROM themselves. These are typically used in a + scalar context, i.e. SELECT x, (SELECT x WHERE y) FROM + table. Requires explicit correlate() call. + + - 'name' is no longer a required constructor argument for + Column(). It (and .key) may now be deferred until the + column is added to a Table. + + - like(), ilike(), contains(), startswith(), endswith() take + an optional keyword argument "escape=<somestring>", which + is set as the escape character using the syntax "x LIKE y + ESCAPE '<somestring>'" [ticket:993], [ticket:791] + + - random() is now a generic sql function and will compile to + the database's random implementation, if any. + + - update().values() and insert().values() take keyword + arguments. + + - Fixed an issue in select() regarding its generation of + FROM clauses, in rare circumstances two clauses could be + produced when one was intended to cancel out the other. + Some ORM queries with lots of eager loads might have seen + this symptom. + + - The case() function now also takes a dictionary as its + whens parameter. It also interprets the "THEN" + expressions as values by default, meaning case([(x==y, + "foo")]) will interpret "foo" as a bound value, not a SQL + expression. Use text(expr) for literal SQL expressions in + this case. For the criterion itself, these may be literal + strings only if the "value" keyword is present, otherwise + SA will force explicit usage of either text() or + literal(). + +- declarative extension + - The "synonym" function is now directly usable with + "declarative". Pass in the decorated property using the + "descriptor" keyword argument, e.g.: somekey = + synonym('_somekey', descriptor=property(g, s)) + + - The "deferred" function is usable with "declarative".
+ Simplest usage is to declare deferred and Column together, + e.g.: data = deferred(Column(Text)) + + - Declarative also gained @synonym_for(...) and + @comparable_using(...), front-ends for synonym and + comparable_property. + + - Improvements to mapper compilation when using declarative; + already-compiled mappers will still trigger compiles of + other uncompiled mappers when used [ticket:995] + + - Declarative will complete setup for Columns lacking names, + allows a more DRY syntax. + + class Foo(Base): + __tablename__ = 'foos' + id = Column(Integer, primary_key=True) + + - inheritance in declarative can be disabled when sending + "inherits=None" to __mapper_args__. + + - declarative_base() takes optional kwarg "mapper", which + is any callable/class/method that produces a mapper, such + as declarative_base(mapper=scopedsession.mapper). This + property can also be set on individual declarative + classes using the "__mapper_cls__" property. + +- postgres + - Got PG server side cursors back into shape, added fixed + unit tests as part of the default test suite. Added + better uniqueness to the cursor ID [ticket:1001] + +- oracle + - The "owner" keyword on Table is now deprecated, and is + exactly synonymous with the "schema" keyword. Tables can + now be reflected with alternate "owner" attributes, + explicitly stated on the Table object or not using + "schema". + + - All of the "magic" searching for synonyms, DBLINKs etc. + during table reflection are disabled by default unless you + specify "oracle_resolve_synonyms=True" on the Table + object. Resolving synonyms necessarily leads to some + messy guessing which we'd rather leave off by default. + When the flag is set, tables and related tables will be + resolved against synonyms in all cases, meaning if a + synonym exists for a particular table, reflection will use + it when reflecting related tables. This is stickier + behavior than before which is why it's off by default. + +- mssql + - Reflected tables will now automatically load other tables + which are referenced by Foreign keys in the auto-loaded + table, [ticket:979]. + + - Added executemany check to skip identity fetch, + [ticket:916]. + + - Added stubs for small date type, [ticket:884] + + - Added a new 'driver' keyword parameter for the pyodbc + dialect. Will substitute into the ODBC connection string + if given, defaults to 'SQL Server'. + + - Added a new 'max_identifier_length' keyword parameter for + the pyodbc dialect. + + - Improvements to pyodbc + Unix. If you couldn't get that + combination to work before, please try again. + +- mysql + - The connection.info keys the dialect uses to cache server + settings have changed and are now namespaced. + +0.4.4 +------ +- sql + - Can again create aliases of selects against textual FROM + clauses, [ticket:975] + + - The value of a bindparam() can be a callable, in which + case it's evaluated at statement execution time to get the + value. + + - Added exception wrapping/reconnect support to result set + fetching. Reconnect works for those databases that raise + a catchable data error during results (i.e. doesn't work + on MySQL) [ticket:978] + + - Implemented two-phase API for "threadlocal" engine, via + engine.begin_twophase(), engine.prepare() [ticket:936] + + - Fixed bug which was preventing UNIONS from being + cloneable, [ticket:986] - - new collection_class api and implementation [ticket:213] - collections are now instrumented via decorations rather than - proxying. 
you can now have collections that manage their own - membership, and your class instance will be directly exposed on the - relation property. the changes are transparent for most users. - - InstrumentedList (as it was) is removed, and relation properties - no longer have 'clear()', '.data', or any other added methods - beyond those provided by the collection type. you are free, of - course, to add them to a custom class. - - __setitem__-like assignments now fire remove events for the - existing value, if any. - - dict-likes used as collection classes no longer need to change - __iter__ semantics- itervalues() is used by default instead. this - is a backwards incompatible change. - - subclassing dict for a mapped collection is no longer needed in - most cases. orm.collections provides canned implementations that - key objects by a specified column or a custom function of your - choice. - - collection assignment now requires a compatible type- assigning - None to clear a collection or assigning a list to a dict - collection will now raise an argument error. - - AttributeExtension moved to interfaces, and .delete is now - .remove The event method signature has also been swapped around. - - - major overhaul for Query: all selectXXX methods - are deprecated. generative methods are now the standard - way to do things, i.e. filter(), filter_by(), all(), one(), - etc. Deprecated methods are docstring'ed with their - new replacements. - - - Class-level properties are now usable as query elements ...no - more '.c.' ! "Class.c.propname" is now superceded by "Class.propname". - All clause operators are supported, as well as higher level operators - such as Class.prop== for scalar attributes, - Class.prop.contains() and Class.prop.any() - for collection-based attributes (all are also negatable). Table-based column - expressions as well as columns mounted on mapped classes via 'c' are of - course still fully available and can be freely mixed with the new attributes. - [ticket:643] + - Added "bind" keyword argument to insert(), update(), + delete() and DDL(). The .bind property is now assignable + on those statements as well as on select(). + + - Insert statements can now be compiled with extra "prefix" + words between INSERT and INTO, for vendor extensions like + MySQL's INSERT IGNORE INTO table. + +- orm + - any(), has(), contains(), ~contains(), attribute level == + and != now work properly with self-referential relations - + the clause inside the EXISTS is aliased on the "remote" + side to distinguish it from the parent table. This + applies to single table self-referential as well as + inheritance-based self-referential. + + - Repaired behavior of == and != operators at the relation() + level when compared against NULL for one-to-one relations + [ticket:985] + + - Fixed bug whereby session.expire() attributes were not + loading on an polymorphically-mapped instance mapped by a + select_table mapper. + + - Added query.with_polymorphic() - specifies a list of + classes which descend from the base class, which will be + added to the FROM clause of the query. Allows subclasses + to be used within filter() criterion as well as eagerly + loads the attributes of those subclasses. + + - Your cries have been heard: removing a pending item from + an attribute or collection with delete-orphan expunges the + item from the session; no FlushError is raised. Note that + if you session.save()'ed the pending item explicitly, the + attribute/collection removal still knocks it out. 
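    A rough sketch of the query.with_polymorphic() method described above,
    assuming a hypothetical Person/Engineer joined-table inheritance setup:

        # Engineer's table is added to the query's FROM clause, so its
        # columns load up front and its attributes are usable in filter()
        q = session.query(Person).with_polymorphic([Engineer])
        engineers = q.filter(Engineer.primary_language == 'python').all()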
+ + - session.refresh() and session.expire() raise an error when + called on instances which are not persistent within the + session + + - Fixed potential generative bug when the same Query was + used to generate multiple Query objects using join(). + + - Fixed bug which was introduced in 0.4.3, whereby loading + an already-persistent instance mapped with joined table + inheritance would trigger a useless "secondary" load from + its joined table, when using the default "select" + polymorphic_fetch. This was due to attributes being + marked as expired during its first load and not getting + unmarked from the previous "secondary" load. Attributes + are now unexpired based on presence in __dict__ after any + load or commit operation succeeds. + + - Deprecated Query methods apply_sum(), apply_max(), + apply_min(), apply_avg(). Better methodologies are + coming.... + + - relation() can accept a callable for its first argument, + which returns the class to be related. This is in place + to assist declarative packages to define relations without + classes yet being in place. + + - Added a new "higher level" operator called "of_type()": + used in join() as well as with any() and has(), qualifies + the subclass which will be used in filter criterion, e.g.: + + query.filter(Company.employees.of_type(Engineer). + any(Engineer.name=='foo')) + + or + + query.join(Company.employees.of_type(Engineer)). + filter(Engineer.name=='foo') + + - Preventive code against a potential lost-reference bug in + flush(). + + - Expressions used in filter(), filter_by() and others, when + they make usage of a clause generated from a relation + using the identity of a child object (e.g., + filter(Parent.child==<somechild>)), evaluate the actual + primary key value of <somechild> at execution time so that + the autoflush step of the Query can complete, thereby + populating the PK value of <somechild> in the case that + <somechild> was pending. + + - setting the relation()-level order by to a column in the + many-to-many "secondary" table will now work with eager + loading, previously the "order by" wasn't aliased against + the secondary table's alias. + + - Synonyms riding on top of existing descriptors are now + full proxies to those descriptors. + +- dialects + - Invalid SQLite connection URLs now raise an error. + + - postgres TIMESTAMP renders correctly [ticket:981] + + - postgres PGArray is a "mutable" type by default; when used + with the ORM, mutable-style equality/copy-on-write + techniques are used to test for changes. + +- extensions + - a new super-small "declarative" extension has been added, + which allows Table and mapper() configuration to take + place inline underneath a class declaration. This + extension differs from ActiveMapper and Elixir in that it + does not redefine any SQLAlchemy semantics at all; literal + Column, Table and relation() constructs are used to define + the class behavior and table definition. + +0.4.3 +------ +- sql + - Added "schema.DDL", an executable free-form DDL statement. + DDLs can be executed in isolation or attached to Table or + MetaData instances and executed automatically when those + objects are created and/or dropped. + + - Table columns and constraints can be overridden on an + existing table (such as a table that was already reflected) + using the 'useexisting=True' flag, which now takes into + account the arguments passed along with it. + + - Added a callable-based DDL events interface, adds hooks + before and after Tables and MetaData create and drop.
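    A small sketch of DDL() per the two items above (the table, index name
    and engine are hypothetical):

        from sqlalchemy.schema import DDL

        # attached form: the statement is emitted automatically right after
        # users_table is created by metadata.create_all()
        DDL("CREATE INDEX ad_hoc_idx ON users (name)").\
            execute_at('after-create', users_table)

        # standalone form: execute a free-form statement against an engine
        DDL("VACUUM").execute(bind=engine)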
+ + - Added generative where() method to delete() and + update() constructs which return a new object with criterion + joined to existing criterion via AND, just like + select().where(). + + - Added "ilike()" operator to column operations. Compiles to + ILIKE on postgres, lower(x) LIKE lower(y) on all + others. [ticket:727] + + - Added "now()" as a generic function; on SQLite, Oracle + and MSSQL compiles as "CURRENT_TIMESTAMP"; "now()" on + all others. [ticket:943] + + - The startswith(), endswith(), and contains() operators now + concatenate the wildcard operator with the given operand in + SQL, i.e. "'%' || " in all cases, accept + text('something') operands properly [ticket:962] + + - cast() accepts text('something') and other non-literal + operands properly [ticket:962] + + - fixed bug in result proxy where anonymously generated + column labels would not be accessible using their straight + string name + + - Deferrable constraints can now be defined. + + - Added "autocommit=True" keyword argument to select() and + text(), as well as generative autocommit() method on + select(); for statements which modify the database through + some user-defined means other than the usual INSERT/UPDATE/ + DELETE etc. This flag will enable "autocommit" behavior + during execution if no transaction is in progress. + [ticket:915] + + - The '.c.' attribute on a selectable now gets an entry for + every column expression in its columns clause. Previously, + "unnamed" columns like functions and CASE statements weren't + getting put there. Now they will, using their full string + representation if no 'name' is available. + + - a CompositeSelect, i.e. any union(), union_all(), + intersect(), etc. now asserts that each selectable contains + the same number of columns. This conforms to the + corresponding SQL requirement. + + - The anonymous 'label' generated for otherwise unlabeled + functions and expressions now propagates outwards at compile + time for expressions like select([select([func.foo()])]). + + - Building on the above ideas, CompositeSelects now build up + their ".c." collection based on the names present in the + first selectable only; corresponding_column() now works + fully for all embedded selectables. + + - Oracle and others properly encode SQL used for defaults like + sequences, etc., even if no unicode idents are used since + identifier preparer may return a cached unicode identifier. + + - Column and clause comparisons to datetime objects on the + left hand side of the expression now work (d < table.c.col). + (datetimes on the RHS have always worked, the LHS exception + is a quirk of the datetime implementation.) + +- orm + - Every Session.begin() must now be accompanied by a + corresponding commit() or rollback() unless the session is + closed with Session.close(). This also includes the begin() + which is implicit to a session created with + transactional=True. The biggest change introduced here is + that when a Session created with transactional=True raises + an exception during flush(), you must call + Session.rollback() or Session.close() in order for that + Session to continue after an exception. + + - Fixed merge() collection-doubling bug when merging transient + entities with backref'ed collections. [ticket:961] + + - merge(dont_load=True) does not accept transient entities, + this is in continuation with the fact that + merge(dont_load=True) does not accept any "dirty" objects + either. + + - Added standalone "query" class attribute generated by a + scoped_session. 
This provides MyClass.query without using + Session.mapper. Use via: + + MyClass.query = Session.query_property() + + - The proper error message is raised when trying to access + expired instance attributes with no session present + + - dynamic_loader() / lazy="dynamic" now accepts and uses + the order_by parameter in the same way in which it works + with relation(). - - removed ancient query.select_by_attributename() capability. - - - the aliasing logic used by eager loading has been generalized, so that - it also adds full automatic aliasing support to Query. It's no longer - necessary to create an explicit Alias to join to the same tables multiple times; - *even for self-referential relationships!!* - - join() and outerjoin() take arguments "aliased=True". this causes - their joins to be built on aliased tables; subsequent calls - to filter() and filter_by() will translate all table expressions - (yes, real expressions using the original mapped Table) to be that of - the Alias for the duration of that join() (i.e. until reset_joinpoint() - or another join() is called). - - join() and outerjoin() take arguments "id=". when used - with "aliased=True", the id can be referenced by add_entity(cls, id=) - so that you can select the joined instances even if they're from an alias. - - join() and outerjoin() now work with self-referential relationships! using - "aliased=True", you can join as many levels deep as desired, i.e. - query.join(['children', 'children'], aliased=True); filter criterion will - be against the rightmost joined table - - - added query.populate_existing() - marks the query to reload - all attributes and collections of all instances touched in the query, - including eagerly-loaded entities [ticket:660] - - - added eagerload_all(), allows eagerload_all('x.y.z') to specify eager - loading of all properties in the given path + - Added expire_all() method to Session. Calls expire() for + all persistent instances. This is handy in conjunction + with... + + - Instances which have been partially or fully expired will + have their expired attributes populated during a regular + Query operation which affects those objects, preventing a + needless second SQL statement for each instance. + + - Dynamic relations, when referenced, create a strong + reference to the parent object so that the query still has a + parent to call against even if the parent is only created + (and otherwise dereferenced) within the scope of a single + expression. [ticket:938] + + - Added a mapper() flag "eager_defaults". When set to True, + defaults that are generated during an INSERT or UPDATE + operation are post-fetched immediately, instead of being + deferred until later. This mimics the old 0.3 behavior. + + - query.join() can now accept class-mapped attributes as + arguments. These can be used in place or in any combination + with strings. In particular this allows construction of + joins to subclasses on a polymorphic relation, i.e.: + + query(Company).join(['employees', Engineer.name]) + + - query.join() can also accept tuples of attribute name/some + selectable as arguments. This allows construction of joins + *from* subclasses of a polymorphic relation, i.e.: + + query(Company).\ + join( + [('employees', people.join(engineer)), Engineer.name] + ) + + - General improvements to the behavior of join() in + conjunction with polymorphic mappers, i.e. joining from/to + polymorphic mappers and properly applying aliases. 
+ + - Fixed/improved behavior when a mapper determines the natural + "primary key" of a mapped join, it will more effectively + reduce columns which are equivalent via foreign key + relation. This affects how many arguments need to be sent + to query.get(), among other things. [ticket:933] + + - The lazy loader can now handle a join condition where the + "bound" column (i.e. the one that gets the parent id sent as + a bind parameter) appears more than once in the join + condition. Specifically this allows the common task of a + relation() which contains a parent-correlated subquery, such + as "select only the most recent child item". [ticket:946] + + - Fixed bug in polymorphic inheritance where an incorrect + exception is raised when base polymorphic_on column does not + correspond to any columns within the local selectable of an + inheriting mapper more than one level deep + + - Fixed bug in polymorphic inheritance which made it difficult + to set a working "order_by" on a polymorphic mapper. + + - Fixed a rather expensive call in Query that was slowing down + polymorphic queries. + + - "Passive defaults" and other "inline" defaults can now be + loaded during a flush() call if needed; in particular, this + allows constructing relations() where a foreign key column + references a server-side-generated, non-primary-key + column. [ticket:954] + + - Additional Session transaction fixes/changes: + - Fixed bug with session transaction management: parent + transactions weren't started on the connection when + adding a connection to a nested transaction. + + - session.transaction now always refers to the innermost + active transaction, even when commit/rollback are called + directly on the session transaction object. + + - Two-phase transactions can now be prepared. + + - When preparing a two-phase transaction fails on one + connection, all the connections are rolled back. + + - session.close() didn't close all transactions when + nested transactions were used. + + - rollback() previously erroneously set the current + transaction directly to the parent of the transaction + that could be rolled back to. Now it rolls back the next + transaction up that can handle it, but sets the current + transaction to it's parent and inactivates the + transactions in between. Inactive transactions can only + be rolled back or closed, any other call results in an + error. + + - autoflush for commit() wasn't flushing for simple + subtransactions. + + - unitofwork flush didn't close the failed transaction + when the session was not in a transaction and commiting + the transaction failed. + + - Miscellaneous tickets: [ticket:940] [ticket:964] + +- general + - Fixed a variety of hidden and some not-so-hidden + compatibility issues for Python 2.3, thanks to new support + for running the full test suite on 2.3. + + - Warnings are now issued as type exceptions.SAWarning. + +- dialects + - Better support for schemas in SQLite (linked in by ATTACH + DATABASE ... AS name). In some cases in the past, schema + names were ommitted from generated SQL for SQLite. This is + no longer the case. + + - table_names on SQLite now picks up temporary tables as well. + + - Auto-detect an unspecified MySQL ANSI_QUOTES mode during + reflection operations, support for changing the mode + midstream. Manual mode setting is still required if no + reflection is used. + + - Fixed reflection of TIME columns on SQLite. 
+ + - Finally added PGMacAddr type to postgres [ticket:580] + + - Reflect the sequence associated to a PK field (typically + with a BEFORE INSERT trigger) under Firebird + + - Oracle assembles the correct columns in the result set + column mapping when generating a LIMIT/OFFSET subquery, + allows columns to map properly to result sets even if + long-name truncation kicks in [ticket:941] + + - MSSQL now includes EXEC in the _is_select regexp, which + should allow row-returning stored procedures to be used. + + - MSSQL now includes an experimental implementation of + LIMIT/OFFSET using the ANSI SQL row_number() function, so it + requires MSSQL-2005 or higher. To enable the feature, add + "has_window_funcs" to the keyword arguments for connect, or + add "?has_window_funcs=1" to your dburi query arguments. + +- ext + - Changed ext.activemapper to use a non-transactional session + for the objectstore. + + - Fixed output order of "['a'] + obj.proxied" binary operation + on association-proxied lists. + +0.4.2p3 +------ +- general + - sub version numbering scheme changed to suite + setuptools version number rules; easy_install -u + should now get this version over 0.4.2. + +- sql + - Text type is properly exported now and does not + raise a warning on DDL create; String types with no + length only raise warnings during CREATE TABLE + [ticket:912] + + - new UnicodeText type is added, to specify an + encoded, unlengthed Text type + + - fixed bug in union() so that select() statements + which don't derive from FromClause objects can be + unioned + +- orm + - fixed bug with session.dirty when using "mutable + scalars" (such as PickleTypes) + + - added a more descriptive error message when flushing + on a relation() that has non-locally-mapped columns + in its primary or secondary join condition + +- dialects + - Fixed reflection of mysql empty string column + defaults. + +0.4.2b (0.4.2p2) +------ +- sql + - changed name of TEXT to Text since its a "generic" + type; TEXT name is deprecated until 0.5. The + "upgrading" behavior of String to Text when no + length is present is also deprecated until 0.5; will + issue a warning when used for CREATE TABLE + statements (String with no length for SQL expression + purposes is still fine) [ticket:912] + + - generative select.order_by(None) / group_by(None) + was not managing to reset order by/group by + criterion, fixed [ticket:924] + +- orm + - suppressing *all* errors in + InstanceState.__cleanup() now. + + - fixed an attribute history bug whereby assigning a + new collection to a collection-based attribute which + already had pending changes would generate incorrect + history [ticket:922] + + - fixed delete-orphan cascade bug whereby setting the + same object twice to a scalar attribute could log it + as an orphan [ticket:925] + + - Fixed cascades on a += assignment to a list-based + relation. - - a rudimental sharding (horizontal scaling) system is introduced. This system - uses a modified Session which can distribute read and write operations among - multiple databases, based on user-defined functions defining the - "sharding strategy". Instances and their dependents can be distributed - and queried among multiple databases based on attribute values, round-robin - approaches or any other user-defined system. [ticket:618] - - - Eager loading has been enhanced to allow even more joins in more places. - It now functions at any arbitrary depth along self-referential - and cyclical structures. 
When loading cyclical structures, specify "join_depth" - on relation() indicating how many times you'd like the table to join - to itself; each level gets a distinct table alias. The alias names - themselves are generated at compile time using a simple counting - scheme now and are a lot easier on the eyes, as well as of course - completely deterministic. [ticket:659] + - synonyms can now be created against props that don't + exist yet, which are later added via add_property(). + This commonly includes backrefs. (i.e. you can make + synonyms for backrefs without worrying about the + order of operations) [ticket:919] + + - fixed bug which could occur with polymorphic "union" + mapper which falls back to "deferred" loading of + inheriting tables - - added composite column properties. This allows you to create a - type which is represented by more than one column, when using the - ORM. Objects of the new type are fully functional in query expressions, - comparisons, query.get() clauses, etc. and act as though they are regular - single-column scalars..except they're not ! - Use the function composite(cls, *columns) inside of the - mapper's "properties" dict, and instances of cls will be - created/mapped to a single attribute, comprised of the values - correponding to *columns [ticket:211] - - - improved support for custom column_property() attributes which - feature correlated subqueries...work better with eager loading now. - - - along with recent speedups to ResultProxy, total number of - function calls significantly reduced for large loads. - test/perf/masseagerload.py reports 0.4 as having the fewest number - of function calls across all SA versions (0.1, 0.2, and 0.3) - - - primary key "collapse" behavior; the mapper will analyze all columns - in its given selectable for primary key "equivalence", that is, - columns which are equivalent via foreign key relationship or via an - explicit inherit_condition. primarily for joined-table inheritance - scenarios where different named PK columns in inheriting tables - should "collapse" into a single-valued (or fewer-valued) primary key. - fixes things like [ticket:611]. - - - joined-table inheritance will now generate the primary key - columns of all inherited classes against the root table of the - join only. This implies that each row in the root table is distinct - to a single instance. If for some rare reason this is not desireable, - explicit primary_key settings on individual mappers will override it. + - the "columns" collection on a mapper/mapped class + (i.e. 'c') is against the mapped table, not the + select_table in the case of polymorphic "union" + loading (this shouldn't be noticeable). - - When "polymorphic" flags are used with joined-table or single-table - inheritance, all identity keys are generated against the root class - of the inheritance hierarchy; this allows query.get() to work - polymorphically using the same caching semantics as a non-polymorphic get. - note that this currently does not work with concrete inheritance. +- ext + - '+', '*', '+=' and '*=' support for association + proxied lists. 
+ +- dialects + - mssql - narrowed down the test for "date"/"datetime" + in MSDate/ MSDateTime subclasses so that incoming + "datetime" objects don't get mis-interpreted as + "date" objects and vice versa, [ticket:923] + +0.4.2a (0.4.2p1) +------ + +- orm + - fixed fairly critical bug whereby the same instance could be listed + more than once in the unitofwork.new collection; most typically + reproduced when using a combination of inheriting mappers and + ScopedSession.mapper, as the multiple __init__ calls per instance + could save() the object with distinct _state objects + + - added very rudimentary yielding iterator behavior to Query. Call + query.yield_per() and evaluate the Query in an + iterative context; every collection of N rows will be packaged up + and yielded. Use this method with extreme caution since it does + not attempt to reconcile eagerly loaded collections across + result batch boundaries, nor will it behave nicely if the same + instance occurs in more than one batch. This means that an eagerly + loaded collection will get cleared out if it's referenced in more than + one batch, and in all cases attributes will be overwritten on instances + that occur in more than one batch. + + - Fixed in-place set mutation operators for set collections and association + proxied sets. [ticket:920] + +- dialects + - Fixed the missing call to subtype result processor for the PGArray + type. [ticket:913] + +0.4.2 +----- +- sql + - generic functions ! we introduce a database of known SQL functions, such + as current_timestamp, coalesce, and create explicit function objects + representing them. These objects have constrained argument lists, are + type aware, and can compile in a dialect-specific fashion. So saying + func.char_length("foo", "bar") raises an error (too many args), + func.coalesce(datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)) + knows that its return type is a Date. We only have a few functions + represented so far but will continue to add to the system [ticket:615] + + - auto-reconnect support improved; a Connection can now automatically + reconnect after its underlying connection is invalidated, without + needing to connect() again from the engine. This allows an ORM session + bound to a single Connection to not need a reconnect. + Open transactions on the Connection must be rolled back after an invalidation + of the underlying connection else an error is raised. Also fixed + bug where disconnect detect was not being called for cursor(), rollback(), + or commit(). - - secondary inheritance loading: polymorphic mappers can be - constructed *without* a select_table argument. inheriting mappers - whose tables were not represented in the initial load will issue a - second SQL query immediately, once per instance (i.e. not very - efficient for large lists), in order to load the remaining - columns. - - secondary inheritance loading can also move its second query into - a column- level "deferred" load, via the "polymorphic_fetch" - argument, which can be set to 'select' or 'deferred' - - - added undefer_group() MapperOption, sets a set of "deferred" - columns joined by a "group" to load as "undeferred". + - added new flag to String and create_engine(), + assert_unicode=(True|False|'warn'|None). Defaults to `False` or `None` on + create_engine() and String, `'warn'` on the Unicode type. When `True`, + results in all unicode conversion operations raising an exception when a + non-unicode bytestring is passed as a bind parameter. 'warn' results + in a warning. 
It is strongly advised that all unicode-aware applications + make proper use of Python unicode objects (i.e. u'hello' and not 'hello') + so that data round trips accurately. + + - generation of "unique" bind parameters has been simplified to use the same + "unique identifier" mechanisms as everything else. This doesn't affect + user code, except any code that might have been hardcoded against the generated + names. Generated bind params now have the form "<paramname>_<num>", + whereas before only the second bind of the same name would have this form. + + - select().as_scalar() will raise an exception if the select does not have + exactly one expression in its columns clause. - - session enhancements/fixes: - - session can be bound to Connections + - bindparam() objects themselves can be used as keys for execute(), i.e. + statement.execute({bind1:'foo', bind2:'bar'}) - - rewrite of the "deterministic alias name" logic to be part of the - SQL layer, produces much simpler alias and label names more in the - style of Hibernate + - added new methods to TypeDecorator, process_bind_param() and + process_result_value(), which automatically take advantage of the processing + of the underlying type. Ideal for using with Unicode or PickleType. + TypeDecorator should now be the primary way to augment the behavior of any + existing type including other TypeDecorator subclasses such as PickleType. + + - selectables (and others) will issue a warning when two columns in + their exported columns collection conflict based on name. + + - tables with schemas can still be used in sqlite, firebird, + schema name just gets dropped [ticket:890] + + - changed the various "literal" generation functions to use an anonymous + bind parameter. not much changes here except their labels now look + like ":param_1", ":param_2" instead of ":literal" + + - column labels in the form "tablename.columnname", i.e. with a dot, are now + supported. + + - from_obj keyword argument to select() can be a scalar or a list. +- orm + - a major behavioral change to collection-based backrefs: they no + longer trigger lazy loads ! "reverse" adds and removes + are queued up and are merged with the collection when it is + actually read from and loaded; but do not trigger a load beforehand. + For users who have noticed this behavior, this should be much more + convenient than using dynamic relations in some cases; for those who + have not, you might notice your apps using a lot fewer queries than + before in some situations. [ticket:871] + + - mutable primary key support is added. primary key columns can be + changed freely, and the identity of the instance will change upon + flush. In addition, update cascades of foreign key referents (primary + key or not) along relations are supported, either in tandem with the + database's ON UPDATE CASCADE (required for DB's like Postgres) or + issued directly by the ORM in the form of UPDATE statements, by setting + the flag "passive_cascades=False". + + - inheriting mappers now inherit the MapperExtensions of their parent + mapper directly, so that all methods for a particular MapperExtension + are called for subclasses as well. As always, any MapperExtension + can return either EXT_CONTINUE to continue extension processing + or EXT_STOP to stop processing. The order of mapper resolution is: + <extensions declared on the class's own mapper>, followed by + <extensions declared on its parent mappers, nearest parent first>.
+ + Note that if you instantiate the same extension class separately + and then apply it individually for two mappers in the same inheritance + chain, the extension will be applied twice to the inheriting class, + and each method will be called twice. + + To apply a mapper extension explicitly to each inheriting class but + have each method called only once per operation, use the same + instance of the extension for both mappers. + [ticket:490] + + - MapperExtension.before_update() and after_update() are now called + symmetrically; previously, an instance that had no modified column + attributes (but had a relation() modification) could be called with + before_update() but not after_update() [ticket:907] + + - columns which are missing from a Query's select statement + now get automatically deferred during load. + + - mapped classes which extend "object" and do not provide an + __init__() method will now raise TypeError if non-empty *args + or **kwargs are present at instance construction time (and are + not consumed by any extensions such as the scoped_session mapper), + consistent with the behavior of normal Python classes [ticket:908] + + - fixed Query bug when filter_by() compares a relation against None + [ticket:899] + + - improved support for pickling of mapped entities. Per-instance + lazy/deferred/expired callables are now serializable so that + they serialize and deserialize with _state. + + - new synonym() behavior: an attribute will be placed on the mapped + class, if one does not exist already, in all cases. if a property + already exists on the class, the synonym will decorate the property + with the appropriate comparison operators so that it can be used in in + column expressions just like any other mapped attribute (i.e. usable in + filter(), etc.) the "proxy=True" flag is deprecated and no longer means + anything. Additionally, the flag "map_column=True" will automatically + generate a ColumnProperty corresponding to the name of the synonym, + i.e.: 'somename':synonym('_somename', map_column=True) will map the + column named 'somename' to the attribute '_somename'. See the example + in the mapper docs. [ticket:801] + + - Query.select_from() now replaces all existing FROM criterion with + the given argument; the previous behavior of constructing a list + of FROM clauses was generally not useful as is required + filter() calls to create join criterion, and new tables introduced + within filter() already add themselves to the FROM clause. The + new behavior allows not just joins from the main table, but select + statements as well. Filter criterion, order bys, eager load + clauses will be "aliased" against the given statement. + + - this month's refactoring of attribute instrumentation changes + the "copy-on-load" behavior we've had since midway through 0.3 + with "copy-on-modify" in most cases. This takes a sizable chunk + of latency out of load operations and overall does less work + as only attributes which are actually modified get their + "committed state" copied. Only "mutable scalar" attributes + (i.e. a pickled object or other mutable item), the reason for + the copy-on-load change in the first place, retain the old + behavior. + + - a slight behavioral change to attributes is, del'ing an attribute + does *not* cause the lazyloader of that attribute to fire off again; + the "del" makes the effective value of the attribute "None". To + re-trigger the "loader" for an attribute, use + session.expire(instance, [attrname]). 
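    For example (a sketch against a hypothetical mapped User with a lazily
    loaded 'profile' attribute):

        del user.profile                    # attribute now evaluates to None; no reload
        session.expire(user, ['profile'])   # re-arms the loader for just that attribute
        user.profile                        # emits a SELECT and repopulates the value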
+ + - query.filter(SomeClass.somechild == None), when comparing + a many-to-one property to None, properly generates "id IS NULL" + including that the NULL is on the right side. + + - query.order_by() takes into account aliased joins, i.e. + query.join('orders', aliased=True).order_by(Order.id) + + - eagerload(), lazyload(), eagerload_all() take an optional + second class-or-mapper argument, which will select the mapper + to apply the option towards. This can select among other + mappers which were added using add_entity(). + + - eagerloading will work with mappers added via add_entity(). + + - added "cascade delete" behavior to "dynamic" relations just like + that of regular relations. if passive_deletes flag (also just added) + is not set, a delete of the parent item will trigger a full load of + the child items so that they can be deleted or updated accordingly. + + - also with dynamic, implemented correct count() behavior as well + as other helper methods. + + - fix to cascades on polymorphic relations, such that cascades + from an object to a polymorphic collection continue cascading + along the set of attributes specific to each element in the collection. + + - query.get() and query.load() do not take existing filter or other + criterion into account; these methods *always* look up the given id + in the database or return the current instance from the identity map, + disregarding any existing filter, join, group_by or other criterion + which has been configured. [ticket:893] + + - added support for version_id_col in conjunction with inheriting mappers. + version_id_col is typically set on the base mapper in an inheritance + relationship where it takes effect for all inheriting mappers. + [ticket:883] + + - relaxed rules on column_property() expressions having labels; any + ColumnElement is accepted now, as the compiler auto-labels non-labeled + ColumnElements now. a selectable, like a select() statement, still + requires conversion to ColumnElement via as_scalar() or label(). + + - fixed backref bug where you could not del instance.attr if attr + was None + + - several ORM attributes have been removed or made private: + mapper.get_attr_by_column(), mapper.set_attr_by_column(), + mapper.pks_by_table, mapper.cascade_callable(), + MapperProperty.cascade_callable(), mapper.canload(), + mapper.save_obj(), mapper.delete_obj(), mapper._mapper_registry, + attributes.AttributeManager + + - Assigning an incompatible collection type to a relation attribute now + raises TypeError instead of sqlalchemy's ArgumentError. + + - Bulk assignment of a MappedCollection now raises an error if a key in the + incoming dictionary does not match the key that the collection's keyfunc + would use for that value. [ticket:886] + + - Custom collections can now specify a @converter method to translate + objects used in "bulk" assignment into a stream of values, as in:: + + obj.col = [newval1, newval2] + # or + obj.dictcol = {'foo': newval1, 'bar': newval2} + + The MappedCollection uses this hook to ensure that incoming key/value + pairs are sane from the collection's perspective. 
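    A short sketch of the @converter hook (the collection class and keying
    function here are hypothetical):

        from sqlalchemy.orm.collections import collection, MappedCollection

        class NodesByName(MappedCollection):
            def __init__(self):
                MappedCollection.__init__(self, keyfunc=lambda node: node.name)

            @collection.converter
            def _convert(self, incoming):
                # accept either a dict or a plain sequence on bulk assignment,
                # yielding the values to be placed into the collection
                if isinstance(incoming, dict):
                    incoming = incoming.values()
                for node in incoming:
                    yield node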
+ + - fixed endless loop issue when using lazy="dynamic" on both + sides of a bi-directional relationship [ticket:872] + + - more fixes to the LIMIT/OFFSET aliasing applied with Query + eagerloads, + in this case when mapped against a select statement [ticket:904] + + - fix to self-referential eager loading such that if the same mapped + instance appears in two or more distinct sets of columns in the same + result set, its eagerly loaded collection will be populated regardless + of whether or not all of the rows contain a set of "eager" columns for + that collection. this would also show up as a KeyError when fetching + results with join_depth turned on. + + - fixed bug where Query would not apply a subquery to the SQL when LIMIT + was used in conjunction with an inheriting mapper where the eager + loader was only in the parent mapper. + + - clarified the error message which occurs when you try to update() + an instance with the same identity key as an instance already present + in the session. + + - some clarifications and fixes to merge(instance, dont_load=True). + fixed bug where lazy loaders were getting disabled on returned instances. + Also, we currently do not support merging an instance which has uncommitted + changes on it, in the case that dont_load=True is used....this will + now raise an error. This is due to complexities in merging the + "committed state" of the given instance to correctly correspond to the + newly copied instance, as well as other modified state. + Since the use case for dont_load=True is caching, the given instances + shouldn't have any uncommitted changes on them anyway. + We also copy the instances over without using any events now, so that + the 'dirty' list on the new session remains unaffected. + + - fixed bug which could arise when using session.begin_nested() in conjunction + with more than one level deep of enclosing session.begin() statements + + - fixed session.refresh() with instance that has custom entity_name + [ticket:914] + +- dialects + + - sqlite SLDate type will not erroneously render "microseconds" portion + of a datetime or time object. + + - oracle + - added disconnect detection support for Oracle + - some cleanup to binary/raw types so that cx_oracle.LOB is detected + on an ad-hoc basis [ticket:902] + + - MSSQL + - PyODBC no longer has a global "set nocount on". + - Fix non-identity integer PKs on autoload [ticket:824] + - Better support for convert_unicode [ticket:839] + - Less strict date conversion for pyodbc/adodbapi [ticket:842] + - Schema-qualified tables / autoload [ticket:901] + + - Firebird backend + + - does properly reflect domains (partially fixing [ticket:410]) and + PassiveDefaults + + - reverted to use default poolclass (was set to SingletonThreadPool in + 0.4.0 [3562] for test purposes) + + - map func.length() to 'char_length' (easily overridable with the UDF + 'strlen' on old versions of Firebird) + + +0.4.1 +----- + - sql - - all "type" keyword arguments, such as those to bindparam(), column(), - Column(), and func.<something>(), renamed to "type_". those objects - still name their "type" attribute as "type". - - transactions: - - added context manager (with statement) support for transactions - - added support for two phase commit, works with mysql and postgres so far. - - added a subtransaction implementation that uses savepoints. - - added support for savepoints. + + - the "shortname" keyword parameter on bindparam() has been + deprecated. + + - Added contains operator (generates a "LIKE %<other>%" clause).
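    For example, a small sketch against a hypothetical 'users' Table:

        # renders a LIKE comparison with surrounding % wildcards
        stmt = users.select(users.c.name.contains('fred'))
        result = engine.execute(stmt)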
+ + - anonymous column expressions are automatically labeled. + e.g. select([x* 5]) produces "SELECT x * 5 AS anon_1". + This allows the labelname to be present in the cursor.description + which can then be appropriately matched to result-column processing + rules. (we can't reliably use positional tracking for result-column + matches since text() expressions may represent multiple columns). + + - operator overloading is now controlled by TypeEngine objects - the + one built-in operator overload so far is String types overloading + '+' to be the string concatenation operator. + User-defined types can also define their own operator overloading + by overriding the adapt_operator(self, op) method. + + - untyped bind parameters on the right side of a binary expression + will be assigned the type of the left side of the operation, to better + enable the appropriate bind parameter processing to take effect + [ticket:819] + + - Removed regular expression step from most statement compilations. + Also fixes [ticket:833] + + - Fixed empty (zero column) sqlite inserts, allowing inserts on + autoincrementing single column tables. + + - Fixed expression translation of text() clauses; this repairs various + ORM scenarios where literal text is used for SQL expressions + + - Removed ClauseParameters object; compiled.params returns a regular + dictionary now, as well as result.last_inserted_params() / + last_updated_params(). + + - Fixed INSERT statements w.r.t. primary key columns that have + SQL-expression based default generators on them; SQL expression + executes inline as normal but will not trigger a "postfetch" condition + for the column, for those DB's who provide it via cursor.lastrowid + + - func. objects can be pickled/unpickled [ticket:844] + + - rewrote and simplified the system used to "target" columns across + selectable expressions. On the SQL side this is represented by the + "corresponding_column()" method. This method is used heavily by the ORM + to "adapt" elements of an expression to similar, aliased expressions, + as well as to target result set columns originally bound to a + table or selectable to an aliased, "corresponding" expression. The new + rewrite features completely consistent and accurate behavior. + + - Added a field ("info") for storing arbitrary data on schema items + [ticket:573] + + - The "properties" collection on Connections has been renamed "info" to + match schema's writable collections. Access is still available via + the "properties" name until 0.5. + + - fixed the close() method on Transaction when using strategy='threadlocal' + + - fix to compiled bind parameters to not mistakenly populate None + [ticket:853] + + - ._execute_clauseelement becomes a public method + Connectable.execute_clauseelement + +- orm + - eager loading with LIMIT/OFFSET applied no longer adds the primary + table joined to a limited subquery of itself; the eager loads now + join directly to the subquery which also provides the primary table's + columns to the result set. This eliminates a JOIN from all eager loads + with LIMIT/OFFSET. [ticket:843] + + - session.refresh() and session.expire() now support an additional argument + "attribute_names", a list of individual attribute keynames to be refreshed + or expired, allowing partial reloads of attributes on an already-loaded + instance. [ticket:802] + + - added op() operator to instrumented attributes; i.e. 
+ User.name.op('ilike')('%somename%') [ticket:767] + + - Mapped classes may now define __eq__, __hash__, and __nonzero__ methods + with arbitrary semantics. The orm now handles all mapped instances on + an identity-only basis. (e.g. 'is' vs '==') [ticket:676] + + - the "properties" accessor on Mapper is removed; it now throws an informative + exception explaining the usage of mapper.get_property() and + mapper.iterate_properties + + - added having() method to Query, applies HAVING to the generated statement + in the same way as filter() appends to the WHERE clause. + + - The behavior of query.options() is now fully based on paths, i.e. an + option such as eagerload_all('x.y.z.y.x') will apply eagerloading to + only those paths, i.e. and not 'x.y.x'; eagerload('children.children') + applies only to exactly two-levels deep, etc. [ticket:777] + + - PickleType will compare using `==` when set up with mutable=False, + and not the `is` operator. To use `is` or any other comparator, send + in a custom comparison function using PickleType(comparator=my_custom_comparator). + + - query doesn't throw an error if you use distinct() and an order_by() + containing UnaryExpressions (or other) together [ticket:848] + + - order_by() expressions from joined tables are properly added to columns + clause when using distinct() [ticket:786] + + - fixed error where Query.add_column() would not accept a class-bound + attribute as an argument; Query also raises an error if an invalid + argument was sent to add_column() (at instances() time) [ticket:858] + + - added a little more checking for garbage-collection dereferences in + InstanceState.__cleanup() to reduce "gc ignored" errors on app + shutdown + + - The session API has been solidified: + + - It's an error to session.save() an object which is already + persistent [ticket:840] + + - It's an error to session.delete() an object which is *not* + persistent. + + - session.update() and session.delete() raise an error when updating + or deleting an instance that is already in the session with a + different identity. + + - The session checks more carefully when determining "object X already + in another session"; e.g. if you pickle a series of objects and + unpickle (i.e. as in a Pylons HTTP session or similar), they can go + into a new session without any conflict + + - merge() includes a keyword argument "dont_load=True". setting this + flag will cause the merge operation to not load any data from the + database in response to incoming detached objects, and will accept + the incoming detached object as though it were already present in + that session. Use this to merge detached objects from external + caching systems into the session. + + - Deferred column attributes no longer trigger a load operation when the + attribute is assigned to. In those cases, the newly assigned value + will be present in the flushes' UPDATE statement unconditionally. + + - Fixed a truncation error when re-assigning a subset of a collection + (obj.relation = obj.relation[1:]) [ticket:834] + + - De-cruftified backref configuration code, backrefs which step on + existing properties now raise an error [ticket:832] + + - Improved behavior of add_property() etc., fixed [ticket:831] involving + synonym/deferred. + + - Fixed clear_mappers() behavior to better clean up after itself. + + - Fix to "row switch" behavior, i.e. when an INSERT/DELETE is combined + into a single UPDATE; many-to-many relations on the parent object + update properly. 
[ticket:841] + + - Fixed __hash__ for association proxy- these collections are unhashable, + just like their mutable Python counterparts. + + - Added proxying of save_or_update, __contains__ and __iter__ methods for + scoped sessions. + + - fixed very hard-to-reproduce issue where by the FROM clause of Query + could get polluted by certain generative calls [ticket:852] + +- dialects + + - Added experimental support for MaxDB (versions >= 7.6.03.007 only). + + - oracle will now reflect "DATE" as an OracleDateTime column, not + OracleDate + + - added awareness of schema name in oracle table_names() function, + fixes metadata.reflect(schema='someschema') [ticket:847] + + - MSSQL anonymous labels for selection of functions made deterministic + + - sqlite will reflect "DECIMAL" as a numeric column. + + - Made access dao detection more reliable [ticket:828] + + - Renamed the Dialect attribute 'preexecute_sequences' to + 'preexecute_pk_sequences'. An attribute porxy is in place for + out-of-tree dialects using the old name. + + - Added test coverage for unknown type reflection. Fixed sqlite/mysql + handling of type reflection for unknown types. + + - Added REAL for mysql dialect (for folks exploiting the + REAL_AS_FLOAT sql mode). + + - mysql Float, MSFloat and MSDouble constructed without arguments + now produce no-argument DDL, e.g.'FLOAT'. + +- misc + + - Removed unused util.hash(). + + +0.4.0 +----- + +- (see 0.4.0beta1 for the start of major changes against 0.3, + as well as http://www.sqlalchemy.org/trac/wiki/WhatsNewIn04 ) + +- Added initial Sybase support (mxODBC so far) [ticket:785] + +- Added partial index support for PostgreSQL. Use the postgres_where keyword + on the Index. + +- string-based query param parsing/config file parser understands + wider range of string values for booleans [ticket:817] + +- backref remove object operation doesn't fail if the other-side + collection doesn't contain the item, supports noload collections + [ticket:813] + +- removed __len__ from "dynamic" collection as it would require issuing + a SQL "count()" operation, thus forcing all list evaluations to issue + redundant SQL [ticket:818] + +- inline optimizations added to locate_dirty() which can greatly speed up + repeated calls to flush(), as occurs with autoflush=True [ticket:816] + +- The IdentifierPreprarer's _requires_quotes test is now regex based. Any + out-of-tree dialects that provide custom sets of legal_characters or + illegal_initial_characters will need to move to regexes or override + _requires_quotes. + +- Firebird has supports_sane_rowcount and supports_sane_multi_rowcount set + to False due to ticket #370 (right way). + +- Improvements and fixes on Firebird reflection: + . FBDialect now mimics OracleDialect, regarding case-sensitivity of TABLE and + COLUMN names (see 'case_sensitive remotion' topic on this current file). + . FBDialect.table_names() doesn't bring system tables (ticket:796). + . FB now reflects Column's nullable property correctly. + +- Fixed SQL compiler's awareness of top-level column labels as used + in result-set processing; nested selects which contain the same column + names don't affect the result or conflict with result-column metadata. + +- query.get() and related functions (like many-to-one lazyloading) + use compile-time-aliased bind parameter names, to prevent + name conflicts with bind parameters that already exist in the + mapped selectable. + +- Fixed three- and multi-level select and deferred inheritance loading + (i.e. 
abc inheritance with no select_table), [ticket:795] + +- Ident passed to id_chooser in shard.py always a list. + +- The no-arg ResultProxy._row_processor() is now the class attribute + `_process_row`. + +- Added support for returning values from inserts and udpates for + PostgreSQL 8.2+. [ticket:797] + +- PG reflection, upon seeing the default schema name being used explicitly + as the "schema" argument in a Table, will assume that this is the the + user's desired convention, and will explicitly set the "schema" argument + in foreign-key-related reflected tables, thus making them match only + with Table constructors that also use the explicit "schema" argument + (even though its the default schema). + In other words, SA assumes the user is being consistent in this usage. + +- fixed sqlite reflection of BOOL/BOOLEAN [ticket:808] + +- Added support for UPDATE with LIMIT on mysql. + +- null foreign key on a m2o doesn't trigger a lazyload [ticket:803] + +- oracle does not implicitly convert to unicode for non-typed result + sets (i.e. when no TypeEngine/String/Unicode type is even being used; + previously it was detecting DBAPI types and converting regardless). + should fix [ticket:800] + +- fix to anonymous label generation of long table/column names [ticket:806] + +- Firebird dialect now uses SingletonThreadPool as poolclass. + +- Firebird now uses dialect.preparer to format sequences names + +- Fixed breakage with postgres and multiple two-phase transactions. Two-phase + commits and and rollbacks didn't automatically end up with a new transaction + as the usual dbapi commits/rollbacks do. [ticket:810] + +- Added an option to the _ScopedExt mapper extension to not automatically + save new objects to session on object initialization. + +- fixed Oracle non-ansi join syntax + +- PickleType and Interval types (on db not supporting it natively) are now + slightly faster. + +- Added Float and Time types to Firebird (FBFloat and FBTime). Fixed + BLOB SUB_TYPE for TEXT and Binary types. + +- Changed the API for the in_ operator. in_() now accepts a single argument + that is a sequence of values or a selectable. The old API of passing in + values as varargs still works but is deprecated. + + +0.4.0beta6 +---------- + +- The Session identity map is now *weak referencing* by default, use + weak_identity_map=False to use a regular dict. The weak dict we are using + is customized to detect instances which are "dirty" and maintain a + temporary strong reference to those instances until changes are flushed. + +- Mapper compilation has been reorganized such that most compilation occurs + upon mapper construction. This allows us to have fewer calls to + mapper.compile() and also to allow class-based properties to force a + compilation (i.e. User.addresses == 7 will compile all mappers; this is + [ticket:758]). The only caveat here is that an inheriting mapper now + looks for its inherited mapper upon construction; so mappers within + inheritance relationships need to be constructed in inheritance order + (which should be the normal case anyway). + +- added "FETCH" to the keywords detected by Postgres to indicate a + result-row holding statement (i.e. in addition to "SELECT"). + +- Added full list of SQLite reserved keywords so that they get escaped + properly. + +- Tightened up the relationship between the Query's generation of "eager + load" aliases, and Query.instances() which actually grabs the eagerly + loaded rows. 
If the aliases were not specifically generated for that
+  statement by EagerLoader, the EagerLoader will not take effect when the
+  rows are fetched. This prevents columns from being grabbed accidentally
+  as being part of an eager load when they were not meant for such, which
+  can happen with textual SQL as well as some inheritance situations. It's
+  particularly important since the "anonymous aliasing" of columns uses
+  simple integer counts now to generate labels.
+
+- Removed "parameters" argument from clauseelement.compile(), replaced with
+  "column_keys". The parameters sent to execute() only interact with the
+  insert/update statement compilation process in terms of the column names
+  present but not the values for those columns. Produces more consistent
+  execute/executemany behavior, simplifies things a bit internally.
+
+- Added 'comparator' keyword argument to PickleType. By default, "mutable"
+  PickleType does a "deep compare" of objects using their dumps()
+  representation. But this doesn't work for dictionaries. Pickled objects
+  which provide an adequate __eq__() implementation can be set up with
+  "PickleType(comparator=operator.eq)" [ticket:560]
+
+- Added session.is_modified(obj) method; performs the same "history"
+  comparison operation as occurs within a flush operation; setting
+  include_collections=False gives the same result as is used when the flush
+  determines whether or not to issue an UPDATE for the instance's row.
+
+- Added "schema" argument to Sequence; use this with Postgres / Oracle when
+  the sequence is located in an alternate schema. Implements part of
+  [ticket:584], should fix [ticket:761].
+
+- Fixed reflection of the empty string for mysql enums.
+
+- Changed MySQL dialect to use the older "LIMIT <offset>, <limit>" syntax
+  instead of "LIMIT <limit> OFFSET <offset>" for folks using 3.23.
+  [ticket:794]
+
+- Added 'passive_deletes="all"' flag to relation(), disables all nulling-out
+  of foreign key attributes during a flush where the parent object is
+  deleted.
+
+- Column defaults and onupdates, executing inline, will add parentheses for
+  subqueries and other parenthesis-requiring expressions
+
+- The behavior whereby String/Unicode types auto-convert to TEXT/CLOB when
+  no length is present now occurs *only* for an exact type of String or
+  Unicode with no arguments. If you use VARCHAR or NCHAR (subclasses of
+  String/Unicode) with no length, they will be interpreted by the dialect
+  as VARCHAR/NCHAR; no "magic" conversion happens there. This is less
+  surprising behavior and in particular this helps Oracle keep string-based
+  bind parameters as VARCHARs and not CLOBs [ticket:793].
+
+- Fixes to ShardedSession to work with deferred columns [ticket:771].
+
+- User-defined shard_chooser() function must accept a "clause=None" argument;
+  this is the ClauseElement passed to session.execute(statement) and can be
+  used to determine the correct shard id (since execute() doesn't take an
+  instance).
+
+- Adjusted operator precedence of NOT to match '==' and others, so that
+  ~(x <op> y) produces NOT (x <op> y), which is more compatible with older
+  MySQL versions [ticket:764]. This doesn't apply to "~(x==y)" as it does
+  in 0.3, since ~(x==y) compiles to "x != y", but it still applies to
+  operators like BETWEEN.
+
+- Other tickets: [ticket:768], [ticket:728], [ticket:779], [ticket:757]
+
+0.4.0beta5
+----------
+
+- Connection pool fixes; the better performance of beta4 remains but fixes
+  "connection overflow" and other bugs which were present (like
+  [ticket:754]).
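+
+  For context, a hypothetical engine configuration exercising the pool
+  settings involved here (URL and values are arbitrary)::
+
+      from sqlalchemy import create_engine
+
+      engine = create_engine('postgres://scott:tiger@localhost/test',
+                             pool_size=5,        # base number of pooled connections
+                             max_overflow=10,    # extra "overflow" connections allowed
+                             pool_timeout=30)    # seconds to wait for a connection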
+ +- Fixed bugs in determining proper sync clauses from custom inherit + conditions. [ticket:769] + +- Extended 'engine_from_config' coercion for QueuePool size / overflow. + [ticket:763] + +- mysql views can be reflected again. [ticket:748] + +- AssociationProxy can now take custom getters and setters. + +- Fixed malfunctioning BETWEEN in orm queries. + +- Fixed OrderedProperties pickling [ticket:762] + +- SQL-expression defaults and sequences now execute "inline" for all + non-primary key columns during an INSERT or UPDATE, and for all columns + during an executemany()-style call. inline=True flag on any insert/update + statement also forces the same behavior with a single execute(). + result.postfetch_cols() is a collection of columns for which the previous + single insert or update statement contained a SQL-side default expression. + +- Fixed PG executemany() behavior, [ticket:759] + +- postgres reflects tables with autoincrement=False for primary key columns + which have no defaults. + +- postgres no longer wraps executemany() with individual execute() calls, + instead favoring performance. "rowcount"/"concurrency" checks with + deleted items (which use executemany) are disabled with PG since psycopg2 + does not report proper rowcount for executemany(). + +- Tickets fixed: + + - [ticket:742] + - [ticket:748] + - [ticket:760] + - [ticket:762] + - [ticket:763] + +0.4.0beta4 +---------- + +- Tidied up what ends up in your namespace when you 'from sqlalchemy import *': + + - 'table' and 'column' are no longer imported. They remain available by + direct reference (as in 'sql.table' and 'sql.column') or a glob import + from the sql package. It was too easy to accidentally use a + sql.expressions.table instead of schema.Table when just starting out + with SQLAlchemy, likewise column. + + - Internal-ish classes like ClauseElement, FromClause, NullTypeEngine, + etc., are also no longer imported into your namespace + + - The 'Smallinteger' compatiblity name (small i!) is no longer imported, + but remains in schema.py for now. SmallInteger (big I!) is still + imported. + +- The connection pool uses a "threadlocal" strategy internally to return + the same connection already bound to a thread, for "contextual" connections; + these are the connections used when you do a "connectionless" execution + like insert().execute(). This is like a "partial" version of the + "threadlocal" engine strategy but without the thread-local transaction part + of it. We're hoping it reduces connection pool overhead as well as + database usage. However, if it proves to impact stability in a negative way, + we'll roll it right back. + +- Fix to bind param processing such that "False" values (like blank strings) + still get processed/encoded. + +- Fix to select() "generative" behavior, such that calling column(), + select_from(), correlate(), and with_prefix() does not modify the + original select object [ticket:752] + +- Added a "legacy" adapter to types, such that user-defined TypeEngine + and TypeDecorator classes which define convert_bind_param() and/or + convert_result_value() will continue to function. Also supports + calling the super() version of those methods. + +- Added session.prune(), trims away instances cached in a session that + are no longer referenced elsewhere. (A utility for strong-ref + identity maps). + +- Added close() method to Transaction. Closes out a transaction using + rollback if it's the outermost transaction, otherwise just ends + without affecting the outer transaction. 
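+
+  A short, hedged illustration of the new close() behavior (the engine URL
+  is a stand-in)::
+
+      from sqlalchemy import create_engine
+
+      engine = create_engine('sqlite://')
+      conn = engine.connect()
+
+      outer = conn.begin()
+      inner = conn.begin()   # not the outermost transaction
+      inner.close()          # ends quietly, outer transaction unaffected
+      outer.rollback()       # the outermost transaction still decides the outcome
+      conn.close()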
+ +- Transactional and non-transactional Session integrates better with + bound connection; a close() will ensure that connection + transactional state is the same as that which existed on it before + being bound to the Session. + +- Modified SQL operator functions to be module-level operators, + allowing SQL expressions to be pickleable. [ticket:735] + +- Small adjustment to mapper class.__init__ to allow for Py2.6 + object.__init__() behavior. + +- Fixed 'prefix' argument for select() + +- Connection.begin() no longer accepts nested=True, this logic is now + all in begin_nested(). + +- Fixes to new "dynamic" relation loader involving cascades + +- Tickets fixed: + + - [ticket:735] + - [ticket:752] + +0.4.0beta3 +---------- + +- SQL types optimization: + + - New performance tests show a combined mass-insert/mass-select test as + having 68% fewer function calls than the same test run against 0.3. + + - General performance improvement of result set iteration is around 10-20%. + + - In types.AbstractType, convert_bind_param() and convert_result_value() + have migrated to callable-returning bind_processor() and + result_processor() methods. If no callable is returned, no pre/post + processing function is called. + + - Hooks added throughout base/sql/defaults to optimize the calling of bind + aram/result processors so that method call overhead is minimized. + + - Support added for executemany() scenarios such that unneeded "last row id" + logic doesn't kick in, parameters aren't excessively traversed. + +- Added 'inherit_foreign_keys' arg to mapper(). + +- Added support for string date passthrough in sqlite. + +- Tickets fixed: + + - [ticket:738] + - [ticket:739] + - [ticket:743] + - [ticket:744] + +0.4.0beta2 +---------- + +- mssql improvements. + +- oracle improvements. + +- Auto-commit after LOAD DATA INFILE for mysql. + +- A rudimental SessionExtension class has been added, allowing user-defined + functionality to take place at flush(), commit(), and rollback() boundaries. + +- Added engine_from_config() function for helping to create_engine() from an + .ini style config. + +- base_mapper() becomes a plain attribute. + +- session.execute() and scalar() can search for a Table with which to bind from + using the given ClauseElement. + +- Session automatically extrapolates tables from mappers with binds, also uses + base_mapper so that inheritance hierarchies bind automatically. + +- Moved ClauseVisitor traversal back to inlined non-recursive. + +- Tickets fixed: + + - [ticket:730] + - [ticket:732] + - [ticket:733] + - [ticket:734] + +0.4.0beta1 +---------- + +- orm + + - Speed! Along with recent speedups to ResultProxy, total number of function + calls significantly reduced for large loads. + + - test/perf/masseagerload.py reports 0.4 as having the fewest number of + function calls across all SA versions (0.1, 0.2, and 0.3). + + - New collection_class api and implementation [ticket:213]. Collections are + now instrumented via decorations rather than proxying. You can now have + collections that manage their own membership, and your class instance will + be directly exposed on the relation property. The changes are transparent + for most users. + + - InstrumentedList (as it was) is removed, and relation properties no + longer have 'clear()', '.data', or any other added methods beyond those + provided by the collection type. You are free, of course, to add them to + a custom class. + + - __setitem__-like assignments now fire remove events for the existing + value, if any. 
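+
+    As a rough sketch of the new instrumentation (the Item and Keyword
+    classes, items_table and the item_keywords association table are
+    assumed to exist already)::
+
+        from sqlalchemy.orm import mapper, relation
+
+        class KeywordSet(set):
+            # an ordinary set subclass; it is instrumented via decoration
+            # and the instance itself appears on item.keywords
+            def by_name(self, name):
+                return [kw for kw in self if kw.name == name]
+
+        mapper(Item, items_table, properties={
+            'keywords': relation(Keyword, secondary=item_keywords,
+                                 collection_class=KeywordSet),
+        })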
+
+  - dict-likes used as collection classes no longer need to change __iter__
+    semantics- itervalues() is used by default instead. This is a backwards
+    incompatible change.
+
+  - Subclassing dict for a mapped collection is no longer needed in most
+    cases. orm.collections provides canned implementations that key objects
+    by a specified column or a custom function of your choice.
+
+  - Collection assignment now requires a compatible type- assigning None to
+    clear a collection or assigning a list to a dict collection will now
+    raise an argument error.
+
+  - AttributeExtension moved to interfaces, and .delete is now .remove. The
+    event method signature has also been swapped around.
+
+  - Major overhaul for Query:
+
+    - All selectXXX methods are deprecated. Generative methods are now the
+      standard way to do things, i.e. filter(), filter_by(), all(), one(),
+      etc. Deprecated methods are docstring'ed with their new replacements.
+
+    - Class-level properties are now usable as query elements... no more
+      '.c.'! "Class.c.propname" is now superseded by "Class.propname". All
+      clause operators are supported, as well as higher level operators such
+      as Class.prop == <some value> for scalar attributes,
+      Class.prop.contains(<some instance>) and Class.prop.any(<criterion>)
+      for collection-based attributes (all are also negatable). Table-based
+      column expressions as well as columns mounted on mapped classes via
+      'c' are of course still fully available and can be freely mixed with
+      the new attributes. [ticket:643]
+
+    - Removed ancient query.select_by_attributename() capability.
+
+    - The aliasing logic used by eager loading has been generalized, so that
+      it also adds full automatic aliasing support to Query. It's no longer
+      necessary to create an explicit Alias to join to the same tables
+      multiple times; *even for self-referential relationships*.
+
+    - join() and outerjoin() take the argument "aliased=True". This causes
+      their joins to be built on aliased tables; subsequent calls to
+      filter() and filter_by() will translate all table expressions (yes,
+      real expressions using the original mapped Table) to be that of the
+      Alias for the duration of that join() (i.e. until reset_joinpoint() or
+      another join() is called).
+
+    - join() and outerjoin() take an "id=<some id>" argument. When used
+      with "aliased=True", the id can be referenced by add_entity(cls,
+      id=<some id>) so that you can select the joined instances even if
+      they're from an alias.
+
+    - join() and outerjoin() now work with self-referential relationships!
+      Using "aliased=True", you can join as many levels deep as desired,
+      i.e. query.join(['children', 'children'], aliased=True); filter
+      criterion will be against the rightmost joined table
+
+    - Added query.populate_existing(), marks the query to reload all
+      attributes and collections of all instances touched in the query,
+      including eagerly-loaded entities. [ticket:660]
+
+    - Added eagerload_all(), allows eagerload_all('x.y.z') to specify eager
+      loading of all properties in the given path.
+
+  - Major overhaul for Session:
+
+    - New function which "configures" a session called "sessionmaker()". Send
+      various keyword arguments to this function once, returns a new class
+      which creates a Session against that stereotype.
+
+    - SessionTransaction removed from "public" API. You now can call begin()/
+      commit()/rollback() on the Session itself.
+
+    - Session also supports SAVEPOINT transactions; call begin_nested().
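+
+      A condensed, hypothetical sketch of the sessionmaker()/savepoint
+      workflow described above ('engine' and a mapped User class with a
+      suitable __init__ are assumed to be defined elsewhere)::
+
+          from sqlalchemy.orm import sessionmaker
+
+          Session = sessionmaker(bind=engine, transactional=True)
+          session = Session()
+
+          session.save(User(name='ed'))
+          session.begin_nested()        # issues a SAVEPOINT
+          session.save(User(name='fred'))
+          session.rollback()            # rolls back to the SAVEPOINT only
+          session.commit()              # ends the outermost transaction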
+ + - Session supports two-phase commit behavior when vertically or + horizontally partitioning (i.e., using more than one engine). Use + twophase=True. + + - Session flag "transactional=True" produces a session which always places + itself into a transaction when first used. Upon commit(), rollback() or + close(), the transaction ends; but begins again on the next usage. + + - Session supports "autoflush=True". This issues a flush() before each + query. Use in conjunction with transactional, and you can just + save()/update() and then query, the new objects will be there. Use + commit() at the end (or flush() if non-transactional) to flush remaining + changes. + + - New scoped_session() function replaces SessionContext and assignmapper. + Builds onto "sessionmaker()" concept to produce a class whos Session() + construction returns the thread-local session. Or, call all Session + methods as class methods, i.e. Session.save(foo); Session.commit(). + just like the old "objectstore" days. + + - Added new "binds" argument to Session to support configuration of + multiple binds with sessionmaker() function. + + - A rudimental SessionExtension class has been added, allowing + user-defined functionality to take place at flush(), commit(), and + rollback() boundaries. + + - Query-based relation()s available with dynamic_loader(). This is a + *writable* collection (supporting append() and remove()) which is also a + live Query object when accessed for reads. Ideal for dealing with very + large collections where only partial loading is desired. + + - flush()-embedded inline INSERT/UPDATE expressions. Assign any SQL + expression, like "sometable.c.column + 1", to an instance's attribute. + Upon flush(), the mapper detects the expression and embeds it directly in + the INSERT or UPDATE statement; the attribute gets deferred on the + instance so it loads the new value the next time you access it. + + - A rudimental sharding (horizontal scaling) system is introduced. This + system uses a modified Session which can distribute read and write + operations among multiple databases, based on user-defined functions + defining the "sharding strategy". Instances and their dependents can be + distributed and queried among multiple databases based on attribute + values, round-robin approaches or any other user-defined + system. [ticket:618] + + - Eager loading has been enhanced to allow even more joins in more places. + It now functions at any arbitrary depth along self-referential and + cyclical structures. When loading cyclical structures, specify + "join_depth" on relation() indicating how many times you'd like the table + to join to itself; each level gets a distinct table alias. The alias + names themselves are generated at compile time using a simple counting + scheme now and are a lot easier on the eyes, as well as of course + completely deterministic. [ticket:659] + + - Added composite column properties. This allows you to create a type which + is represented by more than one column, when using the ORM. Objects of + the new type are fully functional in query expressions, comparisons, + query.get() clauses, etc. and act as though they are regular single-column + scalars... except they're not! Use the function composite(cls, *columns) + inside of the mapper's "properties" dict, and instances of cls will be + created/mapped to a single attribute, comprised of the values correponding + to *columns. 
[ticket:211] + + - Improved support for custom column_property() attributes which feature + correlated subqueries, works better with eager loading now. + + - Primary key "collapse" behavior; the mapper will analyze all columns in + its given selectable for primary key "equivalence", that is, columns which + are equivalent via foreign key relationship or via an explicit + inherit_condition. primarily for joined-table inheritance scenarios where + different named PK columns in inheriting tables should "collapse" into a + single-valued (or fewer-valued) primary key. Fixes things like + [ticket:611]. + + - Joined-table inheritance will now generate the primary key columns of all + inherited classes against the root table of the join only. This implies + that each row in the root table is distinct to a single instance. If for + some rare reason this is not desireable, explicit primary_key settings on + individual mappers will override it. + + - When "polymorphic" flags are used with joined-table or single-table + inheritance, all identity keys are generated against the root class of the + inheritance hierarchy; this allows query.get() to work polymorphically + using the same caching semantics as a non-polymorphic get. Note that this + currently does not work with concrete inheritance. + + - Secondary inheritance loading: polymorphic mappers can be constructed + *without* a select_table argument. inheriting mappers whose tables were + not represented in the initial load will issue a second SQL query + immediately, once per instance (i.e. not very efficient for large lists), + in order to load the remaining columns. + + - Secondary inheritance loading can also move its second query into a + column-level "deferred" load, via the "polymorphic_fetch" argument, which + can be set to 'select' or 'deferred' + + - It's now possible to map only a subset of available selectable columns + onto mapper properties, using include_columns/exclude_columns. + [ticket:696]. + + - Added undefer_group() MapperOption, sets a set of "deferred" columns + joined by a "group" to load as "undeferred". + + - Rewrite of the "deterministic alias name" logic to be part of the SQL + layer, produces much simpler alias and label names more in the style of + Hibernate + +- sql + + - Speed! Clause compilation as well as the mechanics of SQL constructs have + been streamlined and simplified to a signficant degree, for a 20-30% + improvement of the statement construction/compilation overhead of 0.3. + + - All "type" keyword arguments, such as those to bindparam(), column(), + Column(), and func.(), renamed to "type_". Those objects still + name their "type" attribute as "type". + + - case_sensitive=(True|False) setting removed from schema items, since + checking this state added a lot of method call overhead and there was no + decent reason to ever set it to False. Table and column names which are + all lower case will be treated as case-insenstive (yes we adjust for + Oracle's UPPERCASE style too). + + - Transactions: + + - Added context manager (with statement) support for transactions. + - Added support for two phase commit, works with mysql and postgres so far. + - Added a subtransaction implementation that uses savepoints. + - Added support for savepoints. + - MetaData: + + - Tables can be reflected from the database en-masse without declaring + them in advance. MetaData(engine, reflect=True) will load all tables + present in the database, or use metadata.reflect() for finer control. 
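+
+      A hedged example (the database file and table names are arbitrary)::
+
+          from sqlalchemy import MetaData, create_engine
+
+          engine = create_engine('sqlite:///existing.db')
+
+          meta = MetaData(engine, reflect=True)       # loads all tables up front
+          # or, reflecting selectively:
+          meta2 = MetaData(engine)
+          meta2.reflect(only=['users', 'addresses'])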
- DynamicMetaData has been renamed to ThreadLocalMetaData + - The ThreadLocalMetaData constructor now takes no arguments. - BoundMetaData has been removed- regular MetaData is equivalent - - Numeric and Float types now have an "asdecimal" flag; defaults to - True for Numeric, False for Float. when True, values are returned as - decimal.Decimal objects; when False, values are returned as float(). - the defaults of True/False are already the behavior for PG and MySQL's - DBAPI modules. [ticket:646] - - new SQL operator implementation which removes all hardcoded operators - from expression structures and moves them into compilation; - allows greater flexibility of operator compilation; for example, "+" - compiles to "||" when used in a string context, or "concat(a,b)" on - MySQL; whereas in a numeric context it compiles to "+". fixes [ticket:475]. - - "anonymous" alias and label names are now generated at SQL compilation - time in a completely deterministic fashion...no more random hex IDs - - significant architectural overhaul to SQL elements (ClauseElement). - all elements share a common "mutability" framework which allows a - consistent approach to in-place modifications of elements as well as - generative behavior. improves stability of the ORM which makes - heavy usage of mutations to SQL expressions. - - select() and union()'s now have "generative" behavior. methods like + + - Numeric and Float types now have an "asdecimal" flag; defaults to True for + Numeric, False for Float. When True, values are returned as + decimal.Decimal objects; when False, values are returned as float(). The + defaults of True/False are already the behavior for PG and MySQL's DBAPI + modules. [ticket:646] + + - New SQL operator implementation which removes all hardcoded operators from + expression structures and moves them into compilation; allows greater + flexibility of operator compilation; for example, "+" compiles to "||" + when used in a string context, or "concat(a,b)" on MySQL; whereas in a + numeric context it compiles to "+". Fixes [ticket:475]. + + - "Anonymous" alias and label names are now generated at SQL compilation + time in a completely deterministic fashion... no more random hex IDs + + - Significant architectural overhaul to SQL elements (ClauseElement). All + elements share a common "mutability" framework which allows a consistent + approach to in-place modifications of elements as well as generative + behavior. Improves stability of the ORM which makes heavy usage of + mutations to SQL expressions. + + - select() and union()'s now have "generative" behavior. Methods like order_by() and group_by() return a *new* instance - the original instance - is left unchanged. non-generative methods remain as well. - - the internals of select/union vastly simplified - all decision making + is left unchanged. Non-generative methods remain as well. + + - The internals of select/union vastly simplified- all decision making regarding "is subquery" and "correlation" pushed to SQL generation phase. - select() elements are now *never* mutated by their enclosing containers - or by any dialect's compilation process [ticket:52] [ticket:569] + select() elements are now *never* mutated by their enclosing containers or + by any dialect's compilation process [ticket:52] [ticket:569] + - select(scalar=True) argument is deprecated; use select(..).as_scalar(). 
- the resulting object obeys the full "column" interface and plays better - within expressions - - added select().with_prefix('foo') allowing any set of keywords to be + The resulting object obeys the full "column" interface and plays better + within expressions. + + - Added select().with_prefix('foo') allowing any set of keywords to be placed before the columns clause of the SELECT [ticket:504] - - added array slice support to row[] [ticket:686] - - result sets make a better attempt at matching the DBAPI types present - in cursor.description to the TypeEngine objects defined by the dialect, - which are then used for result-processing. Note this only takes effect - for textual SQL; constructed SQL statements always have an explicit type map. - - result sets from CRUD operations close their underlying cursor immediately. - will also autoclose the connection if defined for the operation; this + + - Added array slice support to row[] [ticket:686] + + - Result sets make a better attempt at matching the DBAPI types present in + cursor.description to the TypeEngine objects defined by the dialect, which + are then used for result-processing. Note this only takes effect for + textual SQL; constructed SQL statements always have an explicit type map. + + - Result sets from CRUD operations close their underlying cursor immediately + and will also autoclose the connection if defined for the operation; this allows more efficient usage of connections for successive CRUD operations with less chance of "dangling connections". - - Column defaults and onupdate Python functions (i.e. passed to ColumnDefault) - may take zero or one arguments; the one argument is the ExecutionContext, - from which you can call "context.parameters[someparam]" to access the other - bind parameter values affixed to the statement [ticket:559] - - added "explcit" create/drop/execute support for sequences - (i.e. you can pass a "connectable" to each of those methods - on Sequence) - - better quoting of identifiers when manipulating schemas - - standardized the behavior for table reflection where types can't be located; - NullType is substituted instead, warning is raised. + + - Column defaults and onupdate Python functions (i.e. passed to + ColumnDefault) may take zero or one arguments; the one argument is the + ExecutionContext, from which you can call "context.parameters[someparam]" + to access the other bind parameter values affixed to the statement + [ticket:559]. The connection used for the execution is available as well + so that you can pre-execute statements. + + - Added "explcit" create/drop/execute support for sequences (i.e. you can + pass a "connectable" to each of those methods on Sequence). + + - Better quoting of identifiers when manipulating schemas. + + - Standardized the behavior for table reflection where types can't be + located; NullType is substituted instead, warning is raised. + - ColumnCollection (i.e. the 'c' attribute on tables) follows dictionary semantics for "__contains__" [ticket:606] - + - engines + + - Speed! The mechanics of result processing and bind parameter processing + have been overhauled, streamlined and optimized to issue as little method + calls as possible. Bench tests for mass INSERT and mass rowset iteration + both show 0.4 to be over twice as fast as 0.3, using 68% fewer function + calls. + + - You can now hook into the pool lifecycle and run SQL statements or other + logic at new each DBAPI connection, pool check-out and check-in. 
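+
+    A tentative sketch of such a hook, assuming the PoolListener interface
+    and create_engine()'s 'listeners' argument are available in this
+    release::
+
+        from sqlalchemy import create_engine
+        from sqlalchemy.interfaces import PoolListener
+
+        class ConnectionSetup(PoolListener):
+            def connect(self, dbapi_con, con_record):
+                # run once for each brand-new DBAPI connection
+                dbapi_con.execute("PRAGMA synchronous = OFF")
+
+        engine = create_engine('sqlite://', listeners=[ConnectionSetup()])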
+ - Connections gain a .properties collection, with contents scoped to the lifetime of the underlying DBAPI connection + + - Removed auto_close_cursors and disallow_open_cursors arguments from Pool; + reduces overhead as cursors are normally closed by ResultProxy and + Connection. + - extensions + - proxyengine is temporarily removed, pending an actually working replacement. - - SelectResults has been replaced by Query. SelectResults / - SelectResultsExt still exist but just return a slightly modified - Query object for backwards-compatibility. join_to() method - from SelectResults isn't present anymore, need to use join(). + + - SelectResults has been replaced by Query. SelectResults / + SelectResultsExt still exist but just return a slightly modified Query + object for backwards-compatibility. join_to() method from SelectResults + isn't present anymore, need to use join(). + +- mysql + + - Table and column names loaded via reflection are now Unicode. + + - All standard column types are now supported, including SET. + + - Table reflection can now be performed in as little as one round-trip. + + - ANSI and ANSI_QUOTES sql modes are now supported. + + - Indexes are now reflected. + - postgres - - Added PGArray datatype for using postgres array datatypes + + - Added PGArray datatype for using postgres array datatypes. + - oracle - - very rudimental support for OUT parameters added; use sql.outparam(name, type) - to set up an OUT parameter, just like bindparam(); after execution, values are - avaiable via result.out_parameters dictionary. [ticket:507] + + - Very rudimental support for OUT parameters added; use sql.outparam(name, + type) to set up an OUT parameter, just like bindparam(); after execution, + values are avaiable via result.out_parameters dictionary. [ticket:507] 0.3.11 +------ + +- sql + + - tweak DISTINCT precedence for clauses like + `func.count(t.c.col.distinct())` + + - Fixed detection of internal '$' characters in :bind$params [ticket:719] + + - [ticket:768] dont assume join criterion consists only of column objects + + - adjusted operator precedence of NOT to match '==' and others, so that + ~(x==y) produces NOT (x=y), which is compatible with MySQL < 5.0 + (doesn't like "NOT x=y") [ticket:764] + - orm + - added a check for joining from A->B using join(), along two different m2m tables. this raises an error in 0.3 but is possible in 0.4 when aliases are used. [ticket:687] + + - fixed small exception throw bug in Session.merge() + + - fixed bug where mapper, being linked to a join where one table had + no PK columns, would not detect that the joined table had no PK. 
+ + - fixed bugs in determining proper sync clauses from custom inherit + conditions [ticket:769] + + - backref remove object operation doesn't fail if the other-side + collection doesn't contain the item, supports noload collections + [ticket:813] + +- engine + + - fixed another occasional race condition which could occur + when using pool with threadlocal setting + +- mysql + - fixed specification of YEAR columns when generating schema + - mssql + - added support for TIME columns (simulated using DATETIME) [ticket:679] - - index names are now quoted when dropping from reflected tables [ticket:684] - + + - added support for BIGINT, MONEY, SMALLMONEY, UNIQUEIDENTIFIER and + SQL_VARIANT [ticket:721] + + - index names are now quoted when dropping from reflected tables + [ticket:684] + + - can now specify a DSN for PyODBC, using a URI like mssql:///?dsn=bob + +- postgres + + - when reflecting tables from alternate schemas, the "default" placed upon + the primary key, i.e. usually a sequence name, has the "schema" name + unconditionally quoted, so that schema names which need quoting are fine. + its slightly unnecessary for schema names which don't need quoting + but not harmful. + +- sqlite + - passthrough for stringified dates + +- firebird + - supports_sane_rowcount() set to False due to ticket #370 (right way). + - fixed reflection of Column's nullable property. + +- oracle + - removed LONG_STRING, LONG_BINARY from "binary" types, so type objects + don't try to read their values as LOB [ticket:622], [ticket:751] + 0.3.10 - general - a new mutex that was added in 0.3.9 causes the pool_timeout - feature to fail during a race condition; threads would - raise TimeoutError immediately with no delay if many threads + feature to fail during a race condition; threads would + raise TimeoutError immediately with no delay if many threads push the pool into overflow at the same time. this issue has been fixed. - sql @@ -249,7 +2264,7 @@ - postgres - fixed max identifier length (63) [ticket:571] - + 0.3.9 - general - better error message for NoSuchColumnError [ticket:607] @@ -258,7 +2273,7 @@ - the various "engine" arguments, such as "engine", "connectable", "engine_or_url", "bind_to", etc. are all present, but deprecated. they all get replaced by the single term "bind". you also - set the "bind" of MetaData using + set the "bind" of MetaData using metadata.bind = - ext - iteration over dict association proxies is now dict-like, not @@ -267,7 +2282,7 @@ [ticket:597], and are constructed with a thunk instead - added selectone_by() to assignmapper - orm - - forwards-compatibility with 0.4: added one(), first(), and + - forwards-compatibility with 0.4: added one(), first(), and all() to Query. almost all Query functionality from 0.4 is present in 0.3.9 for forwards-compat purposes. - reset_joinpoint() really really works this time, promise ! lets @@ -276,20 +2291,20 @@ join(['a', 'c']).filter().all() in 0.4 all join() calls start from the "root" - added synchronization to the mapper() construction step, to avoid - thread collisions when pre-existing mappers are compiling in a + thread collisions when pre-existing mappers are compiling in a different thread [ticket:613] - a warning is issued by Mapper when two primary key columns of the same name are munged into a single attribute. this happens frequently - when mapping to joins (or inheritance). + when mapping to joins (or inheritance). 
- synonym() properties are fully supported by all Query joining/ with_parent operations [ticket:598] - fixed very stupid bug when deleting items with many-to-many uselist=False relations - - remember all that stuff about polymorphic_union ? for + - remember all that stuff about polymorphic_union ? for joined table inheritance ? Funny thing... - You sort of don't need it for joined table inheritance, you + You sort of don't need it for joined table inheritance, you can just string all the tables together via outerjoin(). - The UNION still applies if concrete tables are involved, + The UNION still applies if concrete tables are involved, though (since nothing to join them on). - small fix to eager loading to better work with eager loads to polymorphic mappers that are using a straight "outerjoin" @@ -303,15 +2318,15 @@ - DynamicMetaData has been renamed to ThreadLocalMetaData. the DynamicMetaData name is deprecated and is an alias for ThreadLocalMetaData or a regular MetaData if threadlocal=False - - composite primary key is represented as a non-keyed set to allow for + - composite primary key is represented as a non-keyed set to allow for composite keys consisting of cols with the same name; occurs within a Join. helps inheritance scenarios formulate correct PK. - - improved ability to get the "correct" and most minimal set of primary key + - improved ability to get the "correct" and most minimal set of primary key columns from a join, equating foreign keys and otherwise equated columns. - this is also mostly to help inheritance scenarios formulate the best + this is also mostly to help inheritance scenarios formulate the best choice of primary key columns. [ticket:185] - added 'bind' argument to Sequence.create()/drop(), ColumnDefault.execute() - - columns can be overridden in a reflected table with a "key" + - columns can be overridden in a reflected table with a "key" attribute different than the column's name, including for primary key columns [ticket:650] - fixed "ambiguous column" result detection, when dupe col names exist @@ -322,7 +2337,7 @@ - MetaData and all SchemaItems are safe to use with pickle. slow table reflections can be dumped into a pickled file to be reused later. Just reconnect the engine to the metadata after unpickling. [ticket:619] - - added a mutex to QueuePool's "overflow" calculation to prevent a race + - added a mutex to QueuePool's "overflow" calculation to prevent a race condition that can bypass max_overflow - fixed grouping of compound selects to give correct results. will break on sqlite in some cases, but those cases were producing incorrect @@ -330,8 +2345,8 @@ [ticket:623] - fixed precedence of operators so that parenthesis are correctly applied [ticket:620] - - calling .in_() (i.e. with no arguments) will return - "CASE WHEN ( IS NULL) THEN NULL ELSE 0 END = 1)", so that + - calling .in_() (i.e. with no arguments) will return + "CASE WHEN ( IS NULL) THEN NULL ELSE 0 END = 1)", so that NULL or False is returned in all cases, rather than throwing an error [ticket:545] - fixed "where"/"from" criterion of select() to accept a unicode string @@ -339,9 +2354,9 @@ - added standalone distinct() function in addition to column.distinct() [ticket:558] - result.last_inserted_ids() should return a list that is identically - sized to the primary key constraint of the table. values that were + sized to the primary key constraint of the table. values that were "passively" created and not available via cursor.lastrowid will be None. 
- - long-identifier detection fixed to use > rather than >= for + - long-identifier detection fixed to use > rather than >= for max ident length [ticket:589] - fixed bug where selectable.corresponding_column(selectable.c.col) would not return selectable.c.col, if the selectable is a join @@ -356,7 +2371,7 @@ - oracle - datetime fixes: got subsecond TIMESTAMP to work [ticket:604], added OracleDate which supports types.Date with only year/month/day - - added dialect flag "auto_convert_lobs", defaults to True; will cause any + - added dialect flag "auto_convert_lobs", defaults to True; will cause any LOB objects detected in a result set to be forced into OracleBinary so that the LOB is read() automatically, if no typemap was present (i.e., if a textual execute() was issued). @@ -376,12 +2391,12 @@ being too old. - sqlite better handles datetime/date/time objects mixed and matched with various Date/Time/DateTime columns - - string PK column inserts dont get overwritten with OID [ticket:603] + - string PK column inserts dont get overwritten with OID [ticket:603] - mssql - fix port option handling for pyodbc [ticket:634] - now able to reflect start and increment values for identity columns - preliminary support for using scope_identity() with pyodbc - + 0.3.8 - engines - added detach() to Connection, allows underlying DBAPI connection @@ -394,7 +2409,7 @@ object. meaning, if you say someexpr.label('foo') == 5, it produces the correct "someexpr == 5". - _Label propigates "_hide_froms()" so that scalar selects - behave more properly with regards to FROM clause #574 + behave more properly with regards to FROM clause #574 - fix to long name generation when using oid_column as an order by (oids used heavily in mapper queries) - significant speed improvement to ResultProxy, pre-caches @@ -418,7 +2433,7 @@ join(['a', 'b', 'c']). - fixed bug in query.instances() that wouldnt handle more than on additional mapper or one additional column. - - "delete-orphan" no longer implies "delete". ongoing effort to + - "delete-orphan" no longer implies "delete". ongoing effort to separate the behavior of these two operations. - many-to-many relationships properly set the type of bind params for delete operations on the association table @@ -467,7 +2482,7 @@ 0.3.7 - engines - warnings module used for issuing warnings (instead of logging) - - cleanup of DBAPI import strategies across all engines + - cleanup of DBAPI import strategies across all engines [ticket:480] - refactoring of engine internals which reduces complexity, number of codepaths; places more state inside of ExecutionContext @@ -479,8 +2494,8 @@ - improved framework for auto-invalidation of connections that have lost their underlying database, via dialect-specific detection of exceptions corresponding to that database's disconnect - related error messages. Additionally, when a "connection no - longer open" condition is detected, the entire connection pool + related error messages. Additionally, when a "connection no + longer open" condition is detected, the entire connection pool is discarded and replaced with a new instance. #516 - the dialects within sqlalchemy.databases become a setuptools entry points. loading the built-in database dialects works the @@ -492,7 +2507,7 @@ - keys() of result set columns are not lowercased, come back exactly as they're expressed in cursor.description. note this causes colnames to be all caps in oracle. 
- - preliminary support for unicode table names, column names and + - preliminary support for unicode table names, column names and SQL statements added, for databases which can support them. Works with sqlite and postgres so far. Mysql *mostly* works except the has_table() function does not work. Reflection @@ -502,14 +2517,14 @@ of unicode situations that occur in db's such as MS-SQL to be better handled and allows subclassing of the Unicode datatype. [ticket:522] - - ClauseElements can be used in in_() clauses now, such as bind + - ClauseElements can be used in in_() clauses now, such as bind parameters, etc. #476 - reverse operators implemented for `CompareMixin` elements, allows expressions like "5 + somecolumn" etc. #474 - the "where" criterion of an update() and delete() now correlates embedded select() statements against the table being updated or deleted. this works the same as nested select() statement - correlation, and can be disabled via the correlate=False flag on + correlation, and can be disabled via the correlate=False flag on the embedded select(). - column labels are now generated in the compilation phase, which means their lengths are dialect-dependent. So on oracle a label @@ -517,14 +2532,14 @@ on postgres. Also, the true labelname is always attached as the accessor on the parent Selectable so theres no need to be aware of the "truncated" label names [ticket:512]. - - column label and bind param "truncation" also generate - deterministic names now, based on their ordering within the + - column label and bind param "truncation" also generate + deterministic names now, based on their ordering within the full statement being compiled. this means the same statement will produce the same string across application restarts and allowing DB query plan caching to work better. - the "mini" column labels generated when using subqueries, which are to work around glitchy SQLite behavior that doesnt understand - "foo.id" as equivalent to "id", are now only generated in the case + "foo.id" as equivalent to "id", are now only generated in the case that those named columns are selected from (part of [ticket:513]) - the label() method on ColumnElement will properly propigate the TypeEngine of the base element out to the label, including a label() @@ -553,15 +2568,15 @@ version. [ticket:541] - improved query.XXX_by(someprop=someinstance) querying to use similar methodology to with_parent, i.e. using the "lazy" clause - which prevents adding the remote instance's table to the SQL, + which prevents adding the remote instance's table to the SQL, thereby making more complex conditions possible [ticket:554] - added generative versions of aggregates, i.e. sum(), avg(), etc. - to query. used via query.apply_max(), apply_sum(), etc. + to query. used via query.apply_max(), apply_sum(), etc. #552 - - fix to using distinct() or distinct=True in combination with + - fix to using distinct() or distinct=True in combination with join() and similar - - corresponding to label/bindparam name generation, eager loaders - generate deterministic names for the aliases they create using + - corresponding to label/bindparam name generation, eager loaders + generate deterministic names for the aliases they create using md5 hashes. - improved/fixed custom collection classes when giving it "set"/ "sets.Set" classes or subclasses (was still looking for append() @@ -587,7 +2602,7 @@ - oracle: - small fix to allow successive compiles of the same SELECT object which features LIMIT/OFFSET. 
oracle dialect needs to modify - the object to have ROW_NUMBER OVER and wasn't performing + the object to have ROW_NUMBER OVER and wasn't performing the full series of steps on successive compiles. - mysql - support for SSL arguments given as inline within URL query string, @@ -598,20 +2613,20 @@ with MySQL5 but should work with 4.1 series as well. (#557) - extensions - big fix to AssociationProxy so that multiple AssociationProxy - objects can be associated with a single association collection. + objects can be associated with a single association collection. - assign_mapper names methods according to their keys (i.e. __name__) #551 - mssql - - pyodbc is now the preferred DB-API for MSSQL, and if no module is + - pyodbc is now the preferred DB-API for MSSQL, and if no module is specifically requested, will be loaded first on a module probe. - - The @@SCOPE_IDENTITY is now used instead of @@IDENTITY. This + - The @@SCOPE_IDENTITY is now used instead of @@IDENTITY. This behavior may be overridden with the engine_connect - "use_scope_identity" keyword parameter, which may also be specified + "use_scope_identity" keyword parameter, which may also be specified in the dburi. - + 0.3.6 - sql: - bindparam() names are now repeatable! specify two @@ -619,9 +2634,9 @@ and the key will be shared. proper positional/named args translate at compile time. for the old behavior of "aliasing" bind parameters with conflicting names, specify "unique=True" - this option is - still used internally for all the auto-genererated (value-based) - bind parameters. - + still used internally for all the auto-genererated (value-based) + bind parameters. + - slightly better support for bind params as column clauses, either via bindparam() or via literal(), i.e. select([literal('foo')]) @@ -630,7 +2645,7 @@ identical to MetaData except engine_or_url param is required. DynamicMetaData is the same and provides thread-local connections be default. - + - exists() becomes useable as a standalone selectable, not just in a WHERE clause, i.e. exists([columns], criterion).select() @@ -677,7 +2692,7 @@ - the full featureset of the SelectResults extension has been merged into a new set of methods available off of Query. These methods all provide "generative" behavior, whereby the Query is copied - and a new one returned with additional criterion added. + and a new one returned with additional criterion added. The new methods include: filter() - applies select criterion to the query @@ -686,17 +2701,17 @@ join() - join to a property (or across a list of properties) outerjoin() - like join() but uses LEFT OUTER JOIN limit()/offset() - apply LIMIT/OFFSET - range-based access which applies limit/offset: + range-based access which applies limit/offset: session.query(Foo)[3:5] distinct() - apply DISTINCT list() - evaluate the criterion and return results - + no incompatible changes have been made to Query's API and no methods have been deprecated. Existing methods like select(), select_by(), get(), get_by() all execute the query at once and return results like they always did. join_to()/join_via() are still there although the generative join()/outerjoin() methods are easier to use. - + - the return value for multiple mappers used with instances() now returns a cartesian product of the requested list of mappers, represented as a list of tuples. this corresponds to the documented @@ -725,7 +2740,7 @@ the columns placed in the "order by" of Query.select(), that you have explicitly named them in your criterion (i.e. 
you cant rely on the eager loader adding them in for you) - + - added a handy multi-use "identity_key()" method to Session, allowing the generation of identity keys for primary key values, instances, and rows, courtesy Daniel Miller @@ -735,13 +2750,13 @@ - added "refresh-expire" cascade [ticket:492]. allows refresh() and expire() calls to propigate along relationships. - + - more fixes to polymorphic relations, involving proper lazy-clause - generation on many-to-one relationships to polymorphic mappers + generation on many-to-one relationships to polymorphic mappers [ticket:493]. also fixes to detection of "direction", more specific targeting of columns that belong to the polymorphic union vs. those that dont. - + - some fixes to relationship calcs when using "viewonly=True" to pull in other tables into the join condition which arent parent of the relationship's parent/child mappings @@ -749,7 +2764,7 @@ - flush fixes on cyclical-referential relationships that contain references to other instances outside of the cyclical chain, when some of the objects in the cycle are not actually part of the flush - + - put an aggressive check for "flushing object A with a collection of B's, but you put a C in the collection" error condition - **even if C is a subclass of B**, unless B's mapper loads polymorphically. @@ -767,25 +2782,25 @@ like the rest of the SelectResults methods [ticket:472]. But you're going to just use Query now anyway. - - query() method is added by assignmapper. this helps with + - query() method is added by assignmapper. this helps with navigating to all the new generative methods on Query. - ms-sql: - - removed seconds input on DATE column types (probably + - removed seconds input on DATE column types (probably should remove the time altogether) - null values in float fields no longer raise errors - LIMIT with OFFSET now raises an error (MS-SQL has no OFFSET support) - - added an facility to use the MSSQL type VARCHAR(max) instead of TEXT - for large unsized string fields. Use the new "text_as_varchar" to + - added an facility to use the MSSQL type VARCHAR(max) instead of TEXT + for large unsized string fields. Use the new "text_as_varchar" to turn it on. [ticket:509] - - ORDER BY clauses without a LIMIT are now stripped in subqueries, as + - ORDER BY clauses without a LIMIT are now stripped in subqueries, as MS-SQL forbids this usage - - cleanup of module importing code; specifiable DB-API module; more + - cleanup of module importing code; specifiable DB-API module; more explicit ordering of module preferences. [ticket:480] - oracle: @@ -799,13 +2814,13 @@ it improperly propigated bad types. - mysql: - - added a catchall **kwargs to MSString, to help reflection of + - added a catchall **kwargs to MSString, to help reflection of obscure types (like "varchar() binary" in MS 4.0) - - added explicit MSTimeStamp type which takes effect when using + - added explicit MSTimeStamp type which takes effect when using types.TIMESTAMP. 
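
A rough sketch of the generative Query methods merged from SelectResults in the
0.3.6 orm notes above; the mapped User class, its users table and the session are
assumed to already exist:

    q = session.query(User)
    q = q.filter(users.c.name.like('j%')).order_by(users.c.name)
    some_users = q.limit(5).offset(10).list()   # or equivalently q[10:15]
    distinct_users = q.distinct().list()
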
- + 0.3.5 - sql: - the value of "case_sensitive" defaults to True now, regardless of the @@ -841,7 +2856,7 @@ - issues a log warning when a related table cant be reflected due to certain permission errors [ticket:363] - mysql: - - fix to reflection on older DB's that might return array() type for + - fix to reflection on older DB's that might return array() type for "show variables like" statements - postgres: - better reflection of sequences for alternate-schema Tables [ticket:442] @@ -1043,7 +3058,7 @@ - added example/docs for dealing with large collections - added object_session() method to sqlalchemy namespace - fixed QueuePool bug whereby its better able to reconnect to a database -that was not reachable (thanks to Sébastien Lelong), also fixed dispose() +that was not reachable (thanks to Sébastien Lelong), also fixed dispose() method - patch that makes MySQL rowcount work correctly! [ticket:396] - fix to MySQL catch of 2006/2014 errors to properly re-raise OperationalError @@ -1054,8 +3069,8 @@ exception errors, will also prevent transactions getting rolled back accidentally in all DBs [ticket:387] - major speed enhancements vs. 0.3.1, to bring speed -back to 0.2.8 levels - - made conditional dozens of debug log calls that were +back to 0.2.8 levels + - made conditional dozens of debug log calls that were time-intensive to generate log messages - fixed bug in cascade rules whereby the entire object graph could be unnecessarily cascaded on the save/update cascade @@ -1072,13 +3087,13 @@ fixes [ticket:388] - assign_mapper in assignmapper extension returns the created mapper [changeset:2110] - added label() function to Select class, when scalar=True is used -to create a scalar subquery +to create a scalar subquery i.e. "select x, y, (select max(foo) from table) AS foomax from table" - added onupdate and ondelete keyword arguments to ForeignKey; propigate to underlying ForeignKeyConstraint if present. (dont propigate in the other direction, however) - fix to session.update() to preserve "dirty" status of incoming object -- sending a selectable to an IN via the in_() function no longer creates +- sending a selectable to an IN via the in_() function no longer creates a "union" out of multiple selects; only one selectable to a the in_() function is allowed now (make a union yourself if union is needed) - improved support for disabling save-update cascade via cascade="none" etc. @@ -1103,7 +3118,7 @@ invalid kwargs in relation to the selected dialect/pool/engine configuration. - fix to postgres sequence quoting when using schemas - ORM: - the "delete" cascade will load in all child objects, if they were not -loaded already. this can be turned off (i.e. the old behavior) by setting +loaded already. this can be turned off (i.e. the old behavior) by setting passive_deletes=True on a relation(). - adjustments to reworked eager query generation to not fail on circular eager-loaded relationships (like backrefs) @@ -1111,7 +3126,7 @@ eager-loaded relationships (like backrefs) instruct the Query whether or not to use "nesting" when producing a LIMIT query. - fixed bug in circular dependency sorting at flush time; if object A -contained a cyclical many-to-one relationship to object B, and object B +contained a cyclical many-to-one relationship to object B, and object B was just attached to object A, *but* object B itself wasnt changed, the many-to-one synchronize of B's primary key attribute to A's foreign key attribute wouldnt occur. 
[ticket:360] @@ -1120,10 +3135,10 @@ on selectresults [ticket:325] - added an assertion within the "cascade" step of ORM relationships to check that the class of object attached to a parent object is appropriate (i.e. if A.items stores B objects, raise an error if a C is appended to A.items) - - new extension sqlalchemy.ext.associationproxy, provides transparent -"association object" mappings. new example + - new extension sqlalchemy.ext.associationproxy, provides transparent +"association object" mappings. new example examples/association/proxied_association.py illustrates. - - improvement to single table inheritance to load full hierarchies beneath + - improvement to single table inheritance to load full hierarchies beneath the target class - fix to subtle condition in topological sort where a node could appear twice, for [ticket:362] @@ -1148,21 +3163,21 @@ the instance is an "orphan" only if its not attached to *any* of those parents - Specific Databases: - SQLite: - sqlite boolean datatype converts False/True to 0/1 by default - - fixes to Date/Time (SLDate/SLTime) types; works as good as postgres + - fixes to Date/Time (SLDate/SLTime) types; works as good as postgres now [ticket:335] - - MS-SQL: - - fixes bug 261 (table reflection broken for MS-SQL case-sensitive + - MS-SQL: + - fixes bug 261 (table reflection broken for MS-SQL case-sensitive databases) - can now specify port for pymssql - - introduces new "auto_identity_insert" option for auto-switching - between "SET IDENTITY_INSERT" mode when values specified for IDENTITY columns + - introduces new "auto_identity_insert" option for auto-switching + between "SET IDENTITY_INSERT" mode when values specified for IDENTITY columns - now supports multi-column foreign keys - fix to reflecting date/datetime columns - NCHAR and NVARCHAR type support added - Oracle: - Oracle has experimental support for cx_Oracle.TIMESTAMP, which requires a setinputsizes() call on the cursor that is now enabled via the - 'auto_setinputsizes' flag to the oracle dialect. + 'auto_setinputsizes' flag to the oracle dialect. - Firebird: - aliases do not use "AS" - correctly raises NoSuchTableError when reflecting non-existent table @@ -1172,13 +3187,13 @@ the instance is an "orphan" only if its not attached to *any* of those parents useage, greater emphasis on explicitness - the "primary_key" attribute of Table and other selectables becomes a setlike ColumnCollection object; is ordered but not numerically - indexed. a comparison clause between two pks that are derived from the - same underlying tables (i.e. such as two Alias objects) can be generated + indexed. a comparison clause between two pks that are derived from the + same underlying tables (i.e. such as two Alias objects) can be generated via table1.primary_key==table2.primary_key - ForeignKey(Constraint) supports "use_alter=True", to create/drop a foreign key via ALTER. this allows circular foreign key relationships to be set up. - append_item() methods removed from Table and Column; preferably - construct Table/Column/related objects inline, but if needed use + construct Table/Column/related objects inline, but if needed use append_column(), append_foreign_key(), append_constraint(), etc. - table.create() no longer returns the Table object, instead has no return value. 
the usual case is that tables are created via metadata, @@ -1186,7 +3201,7 @@ the instance is an "orphan" only if its not attached to *any* of those parents - added UniqueConstraint (goes at Table level), CheckConstraint (goes at Table or Column level). - index=False/unique=True on Column now creates a UniqueConstraint, - index=True/unique=False creates a plain Index, + index=True/unique=False creates a plain Index, index=True/unique=True on Column creates a unique Index. 'index' and 'unique' keyword arguments to column are now boolean only; for explcit names and groupings of indexes or unique constraints, use the @@ -1201,13 +3216,13 @@ the instance is an "orphan" only if its not attached to *any* of those parents get both the reflected and the programmatic column doubled up - the "foreign_key" attribute on Column and ColumnElement in general is deprecated, in favor of the "foreign_keys" list/set-based attribute, - which takes into account multiple foreign keys on one column. + which takes into account multiple foreign keys on one column. "foreign_key" will return the first element in the "foreign_keys" list/set or None if the list is empty. - Connections/Pooling/Execution: - connection pool tracks open cursors and automatically closes them if connection is returned to pool with cursors still opened. Can be - affected by options which cause it to raise an error instead, or to + affected by options which cause it to raise an error instead, or to do nothing. fixes issues with MySQL, others - fixed bug where Connection wouldnt lose its Transaction after commit/rollback @@ -1221,16 +3236,16 @@ the instance is an "orphan" only if its not attached to *any* of those parents - changed "for_update" parameter to accept False/True/"nowait" and "read", the latter two of which are interpreted only by Oracle and Mysql [ticket:292] - - added extract() function to sql dialect + - added extract() function to sql dialect (SELECT extract(field FROM expr)) - BooleanExpression includes new "negate" argument to specify the appropriate negation operator if one is available. - calling a negation on an "IN" or "IS" clause will result in - "NOT IN", "IS NOT" (as opposed to NOT (x IN y)). + "NOT IN", "IS NOT" (as opposed to NOT (x IN y)). - Function objects know what to do in a FROM clause now. their behavior should be the same, except now you can also do things like - select(['*'], from_obj=[func.my_function()]) to get multiple - columns from the result, or even use sql.column() constructs to name the + select(['*'], from_obj=[func.my_function()]) to get multiple + columns from the result, or even use sql.column() constructs to name the return columns [ticket:172] - ORM: - attribute tracking modified to be more intelligent about detecting @@ -1239,27 +3254,27 @@ the instance is an "orphan" only if its not attached to *any* of those parents including the addition of a MutableType mixin which is implemented by PickleType. unit-of-work now tracks the "dirty" list as an expression of all persistent objects where the attribute manager detects changes. - The basic issue thats fixed is detecting changes on PickleType + The basic issue thats fixed is detecting changes on PickleType objects, but also generalizes type handling and "modified" object checking to be more complete and extensible. - a wide refactoring to "attribute loader" and "options" architectures. 
ColumnProperty and PropertyLoader define their loading behaivor via switchable - "strategies", and MapperOptions no longer use mapper/property copying + "strategies", and MapperOptions no longer use mapper/property copying in order to function; they are instead propigated via QueryContext and SelectionContext objects at query/instances time. All of the internal copying of mappers and properties that was used to handle inheritance as well as options() has been removed; the structure of mappers and properties is much simpler than before and is clearly laid out in the new 'interfaces' module. - - related to the mapper/property overhaul, internal refactoring to - mapper instances() method to use a SelectionContext object to track + - related to the mapper/property overhaul, internal refactoring to + mapper instances() method to use a SelectionContext object to track state during the operation. SLIGHT API BREAKAGE: the append_result() and populate_instances() methods on MapperExtension have a slightly different method signature - now as a result of the change; hoping that these methods are not + now as a result of the change; hoping that these methods are not in widespread use as of yet. - - instances() method moved to Query now, backwards-compatible - version remains on Mapper. + - instances() method moved to Query now, backwards-compatible + version remains on Mapper. - added contains_eager() MapperOption, used in conjunction with instances() to specify properties that should be eagerly loaded from the result set, using their plain column names by default, or translated @@ -1271,29 +3286,29 @@ the instance is an "orphan" only if its not attached to *any* of those parents statements in order of tables across all inherited classes [ticket:321] - added an automatic "row switch" feature to mapping, which will - detect a pending instance/deleted instance pair with the same + detect a pending instance/deleted instance pair with the same identity key and convert the INSERT/DELETE to a single UPDATE - - "association" mappings simplified to take advantage of + - "association" mappings simplified to take advantage of automatic "row switch" feature - "custom list classes" is now implemented via the "collection_class" keyword argument to relation(). the old way still works but is deprecated [ticket:212] - added "viewonly" flag to relation(), allows construction of relations that have no effect on the flush() process. - - added "lockmode" argument to base Query select/get functions, - including "with_lockmode" function to get a Query copy that has - a default locking mode. Will translate "read"/"update" + - added "lockmode" argument to base Query select/get functions, + including "with_lockmode" function to get a Query copy that has + a default locking mode. Will translate "read"/"update" arguments into a for_update argument on the select side. 
[ticket:292] - implemented "version check" logic in Query/Mapper, used when version_id_col is in effect and query.with_lockmode() is used to get() an instance thats already loaded - - post_update behavior improved; does a better job at not + - post_update behavior improved; does a better job at not updating too many rows, updates only required columns [ticket:208] - adjustments to eager loading so that its "eager chain" is kept separate from the normal mapper setup, thereby - preventing conflicts with lazy loader operation, fixes + preventing conflicts with lazy loader operation, fixes [ticket:308] - fix to deferred group loading - session.flush() wont close a connection it opened [ticket:346] @@ -1309,7 +3324,7 @@ the instance is an "orphan" only if its not attached to *any* of those parents outerjoins in queries without the main table getting added twice. [ticket:315] - eagerloading is adjusted to more thoughtfully attach its LEFT OUTER JOINs - to the given query, looking for custom "FROM" clauses that may have + to the given query, looking for custom "FROM" clauses that may have already been set up. - added join_to and outerjoin_to transformative methods to SelectResults, to build up join/outerjoin conditions based on property names. also @@ -1323,15 +3338,15 @@ to 'create_engine', or custom creation function via 'creator' function to 'create_engine'. - added "recycle" argument to Pool, is "pool_recycle" on create_engine, defaults to 3600 seconds; connections after this age will be closed and -replaced with a new one, to handle db's that automatically close +replaced with a new one, to handle db's that automatically close stale connections [ticket:274] -- changed "invalidate" semantics with pooled connection; will -instruct the underlying connection record to reconnect the next +- changed "invalidate" semantics with pooled connection; will +instruct the underlying connection record to reconnect the next time its called. "invalidate" will also automatically be called if any error is thrown in the underlying call to connection.cursor(). this will hopefully allow the connection pool to reconnect to a database that had been stopped and started without restarting -the connecting application [ticket:121] +the connecting application [ticket:121] - eesh ! the tutorial doctest was broken for quite some time. - add_property() method on mapper does a "compile all mappers" step in case the given property references a non-compiled mapper @@ -1339,17 +3354,17 @@ step in case the given property references a non-compiled mapper - [ticket:277] check for pg sequence already existing before create - if a contextual session is established via MapperExtension.get_session (as it is using the sessioncontext plugin, etc), a lazy load operation -will use that session by default if the parent object is not +will use that session by default if the parent object is not persistent with a session already. -- lazy loads will not fire off for an object that does not have a -database identity (why? +- lazy loads will not fire off for an object that does not have a +database identity (why? see http://www.sqlalchemy.org/trac/wiki/WhyDontForeignKeysLoadData) - unit-of-work does a better check for "orphaned" objects that are -part of a "delete-orphan" cascade, for certain conditions where the +part of a "delete-orphan" cascade, for certain conditions where the parent isnt available to cascade from. - mappers can tell if one of their objects is an "orphan" based on interactions with the attribute package. 
this check is based -on a status flag maintained for each relationship +on a status flag maintained for each relationship when objects are attached and detached from each other. - it is now invalid to declare a self-referential relationship with "delete-orphan" (as the abovementioned check would make them impossible @@ -1365,12 +3380,12 @@ with use_information_schema=True argument to create_engine - added case_sensitive argument to MetaData, Table, Column, determines itself automatically based on if a parent schemaitem has a non-None setting for the flag, or if not, then whether the identifier name is all lower -case or not. when set to True, quoting is applied to identifiers with mixed or -uppercase identifiers. quoting is also applied automatically in all cases to -identifiers that are known to be reserved words or contain other non-standard -characters. various database dialects can override all of this behavior, but -currently they are all using the default behavior. tested with postgres, mysql, -sqlite, oracle. needs more testing with firebird, ms-sql. part of the ongoing +case or not. when set to True, quoting is applied to identifiers with mixed or +uppercase identifiers. quoting is also applied automatically in all cases to +identifiers that are known to be reserved words or contain other non-standard +characters. various database dialects can override all of this behavior, but +currently they are all using the default behavior. tested with postgres, mysql, +sqlite, oracle. needs more testing with firebird, ms-sql. part of the ongoing work with [ticket:155] - unit tests updated to run without any pysqlite installed; pool test uses a mock DBAPI @@ -1390,7 +3405,7 @@ count() [ticket:287] 0.2.7 - quoting facilities set up so that database-specific quoting can be turned on for individual table, schema, and column identifiers when -used in all queries/creates/drops. Enabled via "quote=True" in +used in all queries/creates/drops. Enabled via "quote=True" in Table or Column, as well as "quote_schema=True" in Table. Thanks to Aaron Spike for his excellent efforts. - assignmapper was setting is_primary=True, causing all sorts of mayhem @@ -1400,10 +3415,10 @@ primary key columns are null (i.e. when mapping to outer joins etc) - modifcation to unitofwork to not maintain ordering within the "new" list or within the UOWTask "objects" list; instead, new objects are tagged with an ordering identifier as they are registered as new -with the session, and the INSERT statements are then sorted within the +with the session, and the INSERT statements are then sorted within the mapper save_obj. the INSERT ordering has basically been pushed all -the way to the end of the flush cycle. that way the various sorts and -organizations occuring within UOWTask (particularly the circular task +the way to the end of the flush cycle. that way the various sorts and +organizations occuring within UOWTask (particularly the circular task sort) dont have to worry about maintaining order (which they werent anyway) - fixed reflection of foreign keys to autoload the referenced table if it was not loaded already @@ -1414,15 +3429,15 @@ to backrefs by default. specifying a backref() will override this behavior. - better check for ambiguous join conditions in sql.Join; propigates to a better error message in PropertyLoader (i.e. relation()/backref()) for when the join condition can't be reasonably determined. 
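
A small sketch of the per-identifier quoting switch from the 0.2.7 quoting item
above; the table and column names are made up, and BoundMetaData is used for
brevity:

    from sqlalchemy import BoundMetaData, Table, Column, Integer, String

    meta = BoundMetaData('sqlite://')
    orders = Table('Order', meta,
                   Column('Select', Integer, primary_key=True, quote=True),
                   Column('descr', String(40)),
                   quote=True)
    orders.create()   # CREATE TABLE quotes "Order" and "Select" explicitly
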
-- sqlite creates ForeignKeyConstraint objects properly upon table +- sqlite creates ForeignKeyConstraint objects properly upon table reflection. -- adjustments to pool stemming from changes made for [ticket:224]. +- adjustments to pool stemming from changes made for [ticket:224]. overflow counter should only be decremented if the connection actually succeeded. added a test script to attempt testing this. - fixed mysql reflection of default values to be PassiveDefault -- added reflected 'tinyint', 'mediumint' type to MS-SQL [ticket:263], +- added reflected 'tinyint', 'mediumint' type to MS-SQL [ticket:263], [ticket:264] -- SingletonThreadPool has a size and does a cleanup pass, so that +- SingletonThreadPool has a size and does a cleanup pass, so that only a given number of thread-local connections stay around (needed for sqlite applications that dispose of threads en masse) - fixed small pickle bug(s) with lazy loaders [ticket:265] [ticket:267] @@ -1434,14 +3449,14 @@ return an array instead of string for SHOW CREATE TABLE call - fixed ms-sql connect() to work with adodbapi - added "nowait" flag to Select() - inheritance check uses issubclass() instead of direct __mro__ check -to make sure class A inherits from B, allowing mapper inheritance to more +to make sure class A inherits from B, allowing mapper inheritance to more flexibly correspond to class inheritance [ticket:271] - SelectResults will use a subselect, when calling an aggregate (i.e. max, min, etc.) on a SelectResults that has an ORDER BY clause [ticket:252] - fixes to types so that database-specific types more easily used; fixes to mysql text types to work with this methodology -[ticket:269] +[ticket:269] - some fixes to sqlite date type organization - added MSTinyInteger to MS-SQL [ticket:263] @@ -1453,7 +3468,7 @@ Existing methods of primary/foreign key creation have not been changed but use these new objects behind the scenes. table creation and reflection is now more table oriented rather than column oriented. [ticket:76] -- overhaul to MapperExtension calling scheme, wasnt working very well +- overhaul to MapperExtension calling scheme, wasnt working very well previously - tweaks to ActiveMapper, supports self-referential relationships - slight rearrangement to objectstore (in activemapper/threadlocal) @@ -1470,28 +3485,28 @@ this also adds them to activemapper - connection exceptions wrapped in DBAPIError - ActiveMapper now supports autoloading column definitions from the database if you supply a __autoload__ = True attribute in your -mapping inner-class. Currently this does not support reflecting +mapping inner-class. Currently this does not support reflecting any relationships. -- deferred column load could screw up the connection status in +- deferred column load could screw up the connection status in a flush() under some circumstances, this was fixed - expunge() was not working with cascade, fixed. - potential endless loop in cascading operations fixed. -- added "synonym()" function, applied to properties to have a +- added "synonym()" function, applied to properties to have a propname the same as another, for the purposes of overriding props and allowing the original propname to be accessible in select_by(). 
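
One plausible shape of the synonym() usage from the item just above, approximated
from how the feature was documented in later releases; the User class, users table
and the '_name'/'name' keys are hypothetical:

    from sqlalchemy.orm import mapper, synonym   # import paths as in later releases

    mapper(User, users, properties={
        '_name': users.c.name,        # column mapped under a private key
        'name': synonym('_name'),     # 'name' remains usable in select_by()
    })
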
- fix to typing in clause construction which specifically helps type issues with polymorphic_union (CAST/ColumnClause propigates its type to proxy columns) -- mapper compilation work ongoing, someday it'll work....moved +- mapper compilation work ongoing, someday it'll work....moved around the initialization of MapperProperty objects to be after all mappers are created to better handle circular compilations. -do_init() method is called on all properties now which are more +do_init() method is called on all properties now which are more aware of their "inherited" status if so. - eager loads explicitly disallowed on self-referential relationships, or relationships to an inheriting mapper (which is also self-referential) -- reduced bind param size in query._get to appease the picky oracle +- reduced bind param size in query._get to appease the picky oracle [ticket:244] -- added 'checkfirst' argument to table.create()/table.drop(), as +- added 'checkfirst' argument to table.create()/table.drop(), as well as table.exists() [ticket:234] - some other ongoing fixes to inheritance [ticket:245] - attribute/backref/orphan/history-tracking tweaks as usual... @@ -1499,16 +3514,16 @@ well as table.exists() [ticket:234] 0.2.5 - fixed endless loop bug in select_by(), if the traversal hit two mappers that referenced each other -- upgraded all unittests to insert './lib/' into sys.path, +- upgraded all unittests to insert './lib/' into sys.path, working around new setuptools PYTHONPATH-killing behavior - further fixes with attributes/dependencies/etc.... - improved error handling for when DynamicMetaData is not connected - MS-SQL support largely working (tested with pymssql) -- ordering of UPDATE and DELETE statements within groups is now +- ordering of UPDATE and DELETE statements within groups is now in order of primary key values, for more deterministic ordering - after_insert/delete/update mapper extensions now called per object, not per-object-per-table -- further fixes/refactorings to mapper compilation +- further fixes/refactorings to mapper compilation 0.2.4 - try/except when the mapper sets init.__name__ on a mapped class, @@ -1520,7 +3535,7 @@ to be in a Session to do the operation; whereas before the operation would just return a blank list or None, it now raises an exception. - Session.update() is slightly more lenient if the session to which the given object was formerly attached to was garbage collected; -otherwise still requires you explicitly remove the instance from +otherwise still requires you explicitly remove the instance from the previous Session. - fixes to mapper compilation, checking for more error conditions - small fix to eager loading combined with ordering/limit/offset @@ -1533,10 +3548,10 @@ properly saving - when QueuePool times out it raises a TimeoutError instead of erroneously making another connection - Queue.Queue usage in pool has been replaced with a locally -modified version (works in py2.3/2.4!) that uses a threading.RLock -for a mutex. this is to fix a reported case where a ConnectionFairy's -__del__() method got called within the Queue's get() method, which -then returns its connection to the Queue via the the put() method, +modified version (works in py2.3/2.4!) that uses a threading.RLock +for a mutex. this is to fix a reported case where a ConnectionFairy's +__del__() method got called within the Queue's get() method, which +then returns its connection to the Queue via the the put() method, causing a reentrant hang unless threading.RLock is used. 
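
The checkfirst/exists() additions noted a few items above ([ticket:234]) amount to
roughly the following; the table definition is a placeholder:

    from sqlalchemy import BoundMetaData, Table, Column, Integer

    meta = BoundMetaData('sqlite://')
    things = Table('things', meta, Column('id', Integer, primary_key=True))

    things.create(checkfirst=True)   # emits CREATE TABLE only if not already present
    assert things.exists()
    things.drop(checkfirst=True)
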
- postgres will not place SERIAL keyword on a primary key column if it has a foreign key constraint @@ -1557,13 +3572,13 @@ when backrefs were in use - the attribute instrumentation module has been completely rewritten; its now a large degree simpler and clearer, slightly faster. the "history" of an attribute is no longer micromanaged with each change and is -instead part of a "CommittedState" object created when the +instead part of a "CommittedState" object created when the instance is first loaded. HistoryArraySet is gone, the behavior of list attributes is now more open ended (i.e. theyre not sets anymore). - py2.4 "set" construct used internally, falls back to sets.Set when "set" not available/ordering is needed. -- fix to transaction control, so that repeated rollback() calls -dont fail (was failing pretty badly when flush() would raise +- fix to transaction control, so that repeated rollback() calls +dont fail (was failing pretty badly when flush() would raise an exception in a larger try/except transaction block) - "foreignkey" argument to relation() can also be a list. fixed auto-foreignkey detection [ticket:151] @@ -1584,16 +3599,16 @@ refactorings __doc__ from the original class - fixed small bug in selectresult.py regarding mapper extension [ticket:200] -- small tweak to cascade_mappers, not very strongly supported +- small tweak to cascade_mappers, not very strongly supported function at the moment - some fixes to between(), column.between() to propigate typing information better [ticket:202] -- if an object fails to be constructed, is not added to the +- if an object fails to be constructed, is not added to the session [ticket:203] - CAST function has been made into its own clause object with its own compilation function in ansicompiler; allows MySQL to silently ignore most CAST calls since MySQL -seems to only support the standard CAST syntax with Date types. +seems to only support the standard CAST syntax with Date types. MySQL-compatible CAST support for strings, ints, etc. a TODO 0.2.2 @@ -1605,16 +3620,16 @@ more unit tests - fix to docs, removed incorrect info that close() is unsafe to use with threadlocal strategy (its totally safe !) - create_engine() can take URLs as string or unicode [ticket:188] -- firebird support partially completed; +- firebird support partially completed; thanks to James Ralston and Brad Clements for their efforts. - Oracle url translation was broken, fixed, will feed host/port/sid -into cx_oracle makedsn() if 'database' field is present, else uses +into cx_oracle makedsn() if 'database' field is present, else uses straight TNS name from the 'host' field - fix to using unicode criterion for query.get()/query.load() -- count() function on selectables now uses table primary key or +- count() function on selectables now uses table primary key or first column instead of "1" for criterion, also uses label "rowcount" -instead of "count". -- got rudimental "mapping to multiple tables" functionality cleaned up, +instead of "count". +- got rudimental "mapping to multiple tables" functionality cleaned up, more correctly documented - restored global_connect() function, attaches to a DynamicMetaData instance called "default_metadata". leaving MetaData arg to Table @@ -1626,11 +3641,11 @@ out will use the default metadata. 
0.2.1 - "pool" argument to create_engine() properly propigates - fixes to URL, raises exception if not parsed, does not pass blank -fields along to the DB connect string (a string such as +fields along to the DB connect string (a string such as user:host@/db was breaking on postgres) - small fixes to Mapper when it inserts and tries to get new primary key values back -- rewrote half of TLEngine, the ComposedSQLEngine used with +- rewrote half of TLEngine, the ComposedSQLEngine used with 'strategy="threadlocal"'. it now properly implements engine.begin()/ engine.commit(), which nest fully with connection.begin()/trans.commit(). added about six unittests. @@ -1638,7 +3653,7 @@ added about six unittests. unittest which was supposed to check for this was also silently missing it. fixed unittest to ensure that ConnectionFairy properly falls out of scope. -- placeholder dispose() method added to SingletonThreadPool, doesnt +- placeholder dispose() method added to SingletonThreadPool, doesnt do anything yet - rollback() is automatically called when an exception is raised, but only if theres no transaction in process (i.e. works more like @@ -1657,28 +3672,28 @@ driver://user:password@host:port/database - total rewrite of connection-scoping methodology, Connection objects can now execute clause elements directly, added explicit "close" as well as support throughout Engine/ORM to handle closing properly, -no longer relying upon __del__ internally to return connections +no longer relying upon __del__ internally to return connections to the pool [ticket:152]. - overhaul to Session interface and scoping. uses hibernate-style methods, including query(class), save(), save_or_update(), etc. no threadlocal scope is installed by default. Provides a binding interface to specific Engines and/or Connections so that underlying Schema objects do not need to be bound to an Engine. Added a basic -SessionTransaction object that can simplistically aggregate transactions +SessionTransaction object that can simplistically aggregate transactions across multiple engines. - overhaul to mapper's dependency and "cascade" behavior; dependency logic factored out of properties.py into a separate module "dependency.py". -"cascade" behavior is now explicitly controllable, proper implementation -of "delete", "delete-orphan", etc. dependency system can now determine at -flush time if a child object has a parent or not so that it makes better +"cascade" behavior is now explicitly controllable, proper implementation +of "delete", "delete-orphan", etc. dependency system can now determine at +flush time if a child object has a parent or not so that it makes better decisions on how that child should be updated in the DB with regards to deletes. - overhaul to Schema to build upon MetaData object instead of an Engine. Entire SQL/Schema system can be used with no Engines whatsoever, executed -solely by an explicit Connection object. the "bound" methodlogy exists via the +solely by an explicit Connection object. the "bound" methodlogy exists via the BoundMetaData for schema objects. ProxyEngine is generally not needed anymore and is replaced by DynamicMetaData. 
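
Taken together, the 0.2.0 items above boil down to a pattern like this; the URL
credentials and the table are placeholders:

    from sqlalchemy import create_engine, BoundMetaData, Table, Column, Integer

    engine = create_engine('postgres://scott:tiger@localhost:5432/test')
    meta = BoundMetaData(engine)
    users = Table('users', meta, Column('id', Integer, primary_key=True))

    conn = engine.connect()               # explicit connection scoping
    result = conn.execute(users.select()) # clause elements executed directly
    conn.close()                          # returns the connection to the pool
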
- true polymorphic behavior implemented, fixes [ticket:167] -- "oid" system has been totally moved into compile-time behavior; +- "oid" system has been totally moved into compile-time behavior; if they are used in an order_by where they are not available, the order_by doesnt get compiled, fixes [ticket:147] - overhaul to packaging; "mapping" is now "orm", "objectstore" is now @@ -1686,7 +3701,7 @@ doesnt get compiled, fixes [ticket:147] "threadlocal" mod if used - mods now called in via "import ". extensions favored over mods as mods are globally-monkeypatching -- fix to add_property so that it propigates properties to inheriting +- fix to add_property so that it propigates properties to inheriting mappers [ticket:154] - backrefs create themselves against primary mapper of its originating property, priamry/secondary join arguments can be specified to override. @@ -1696,10 +3711,10 @@ helps their usage with polymorphic mappers - improvements and fixes to topological sort algorithm, as well as more unit tests - tutorial page added to docs which also can be run with a custom doctest -runner to ensure its properly working. docs generally overhauled to +runner to ensure its properly working. docs generally overhauled to deal with new code patterns - many more fixes, refactorings. -- migration guide is available on the Wiki at +- migration guide is available on the Wiki at http://www.sqlalchemy.org/trac/wiki/02Migration 0.1.7 @@ -1718,7 +3733,7 @@ http://www.sqlalchemy.org/trac/wiki/02Migration - fix for parenthesis to work correctly with subqueries in INSERT/UPDATE - HistoryArraySet gets extend() method - fixed lazyload support for other comparison operators besides = -- lazyload fix where two comparisons in the join condition point to the +- lazyload fix where two comparisons in the join condition point to the samem column - added "construct_new" flag to mapper, will use __new__ to create instances instead of __init__ (standard in 0.2) @@ -1741,7 +3756,7 @@ core functionality, using the function "install_mods(*modnames)". return generators that turn ranges into LIMIT/OFFSET queries (Jonas Borgstr?- factored out querying capabilities of Mapper into a separate Query object which is Session-centric. this improves the performance of mapper.using(session) and makes other things possible. -- objectstore/Session refactored, the official way to save objects is now +- objectstore/Session refactored, the official way to save objects is now via the flush() method. The begin/commit functionality of Session is factored into LegacySession which is still established as the default behavior, until the 0.2 series. @@ -1787,7 +3802,7 @@ modify the population of object attributes. this method can call the populate_instance() method on another mapper to proxy the attribute population from one mapper to another; some row translation logic is also built in to help with this. -- fixed Oracle8-compatibility "use_ansi" flag which converts JOINs to +- fixed Oracle8-compatibility "use_ansi" flag which converts JOINs to comparisons with the = and (+) operators, passes basic unittests - tweaks to Oracle LIMIT/OFFSET support - Oracle reflection uses ALL_** views instead of USER_** to get larger @@ -1796,14 +3811,14 @@ list of stuff to reflect from - objectstore.commit(obj1, obj2,...) 
adds an extra step to seek out private relations on properties and delete child objects, even though its not a global commit -- lots and lots of fixes to mappers which use inheritance, strengthened the +- lots and lots of fixes to mappers which use inheritance, strengthened the concept of relations on a mapper being made towards the "local" table for that mapper, not the tables it inherits. allows more complex compositional patterns to work with lazy/eager loading. -- added support for mappers to inherit from others based on the same table, +- added support for mappers to inherit from others based on the same table, just specify the same table as that of both parent/child mapper. -- some minor speed improvements to the attributes system with regards to -instantiating and populating new objects. +- some minor speed improvements to the attributes system with regards to +instantiating and populating new objects. - fixed MySQL binary unit test - INSERTs can receive clause elements as VALUES arguments, not just literal values @@ -1831,7 +3846,7 @@ keynames they are now generated from a column "label" in all relevant cases to take advantage of excess-name-length rules, and checks for a peculiar collision against a column named the same as "tablename_colname" added - major overhaul to unit of work documentation, other documentation sections. -- fixed attributes bug where if an object is committed, its lazy-loaded list got +- fixed attributes bug where if an object is committed, its lazy-loaded list got blown away if it hadnt been loaded - added unique_connection() method to engine, connection pool to return a connection that is not part of the thread-local context or any current @@ -1857,16 +3872,16 @@ correctly, also relations set up against a mapper with inherited mappers will create joins against the table that is specific to the mapper itself (i.e. and not any tables that are inherited/are further down the inheritance chain), this can be overridden by using custom primary/secondary joins. -- added J.Ellis patch to mapper.py so that selectone() throws an exception -if query returns more than one object row, selectfirst() to not throw the +- added J.Ellis patch to mapper.py so that selectone() throws an exception +if query returns more than one object row, selectfirst() to not throw the exception. also adds selectfirst_by (synonymous with get_by) and selectone_by - added onupdate parameter to Column, will exec SQL/python upon an update statement.Also adds "for_update=True" to all DefaultGenerator subclasses -- added support for Oracle table reflection contributed by Andrija Zaric; +- added support for Oracle table reflection contributed by Andrija Zaric; still some bugs to work out regarding composite primary keys/dictionary selection - checked in an initial Firebird module, awaiting testing. -- added sql.ClauseParameters dictionary object as the result for -compiled.get_params(), does late-typeprocessing of bind parameters so +- added sql.ClauseParameters dictionary object as the result for +compiled.get_params(), does late-typeprocessing of bind parameters so that the original values are easier to access - more docs for indexes, column defaults, connection pooling, engine construction - overhaul to the construction of the types system. uses a simpler inheritance @@ -1874,8 +3889,8 @@ pattern so that any of the generic types can be easily subclassed, with no need for TypeDecorator. 
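
A sketch of the column-level onupdate default described above; the table is made up
and the MetaData-style setup from later releases is used for brevity:

    from sqlalchemy import MetaData, Table, Column, Integer, DateTime, func

    meta = MetaData()
    docs = Table('documents', meta,
                 Column('id', Integer, primary_key=True),
                 Column('updated', DateTime, onupdate=func.now()))
    # onupdate fires its SQL (or Python callable) whenever an UPDATE touches
    # the row; 'default' works the same way for INSERTs
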
- added "convert_unicode=False" parameter to SQLEngine, will cause all String types to perform unicode encoding/decoding (makes Strings act like Unicodes) -- added 'encoding="utf8"' parameter to engine. the given encoding will be -used for all encode/decode calls within Unicode types as well as Strings +- added 'encoding="utf8"' parameter to engine. the given encoding will be +used for all encode/decode calls within Unicode types as well as Strings when convert_unicode=True. - improved support for mapping against UNIONs, added polymorph.py example to illustrate multi-class mapping against a UNION @@ -1887,7 +3902,7 @@ that will be passed to the backref. - SQL functions (i.e. func.foo()) can do execute()/scalar() standalone - fix to SQL functions so that the ANSI-standard functions, i.e. current_timestamp etc., do not specify parenthesis. all other functions do. -- added settattr_clean and append_clean to SmartProperty, which set +- added settattr_clean and append_clean to SmartProperty, which set attributes without triggering a "dirty" event or any history. used as: myclass.prop1.setattr_clean(myobject, 'hi') - improved support to column defaults when used by mappers; mappers will pull @@ -1922,10 +3937,10 @@ such as in an inheritance relationship, this is fixed. producing selects, inserts, etc. without any engine dependencies. builds upon new TableClause/ColumnClause lexical objects. Schema's Table/Column objects are the "physical" subclasses of them. simplifies schema/sql relationship, -extensions (like proxyengine), and speeds overall performance by a large margin. +extensions (like proxyengine), and speeds overall performance by a large margin. removes the entire getattr() behavior that plagued 0.1.1. - refactoring of how the mapper "synchronizes" data between two objects into a -separate module, works better with properties attached to a mapper that has an +separate module, works better with properties attached to a mapper that has an additional inheritance relationship to one of the related tables, also the same methodology used to synchronize parent/child objects now used by mapper to synchronize between inherited and inheriting mappers. @@ -1934,12 +3949,12 @@ check when object attributes are modified or the object is deleted - Index object fully implemented, can be constructed standalone, or via "index" and "unique" arguments on Columns. - added "convert_unicode" flag to SQLEngine, will treat all String/CHAR types -as Unicode types, with raw-byte/utf-8 translation on the bind parameter and +as Unicode types, with raw-byte/utf-8 translation on the bind parameter and result set side. - postgres maintains a list of ANSI functions that must have no parenthesis so function calls with no arguments work consistently - tables can be created with no engine specified. this will default their engine -to a module-scoped "default engine" which is a ProxyEngine. this engine can +to a module-scoped "default engine" which is a ProxyEngine. this engine can be connected via the function "global_connect". - added "refresh(*obj)" method to objectstore / Session to reload the attributes of any set of objects from the database unconditionally @@ -1950,7 +3965,7 @@ normally. broke nothing, slowed down everything. thanks to jpellerin for findi 0.1.1 - small fix to Function class so that expressions with a func.foo() use the type of the -Function object (i.e. the left side) as the type of the boolean expression, not the +Function object (i.e. 
the left side) as the type of the boolean expression, not the other side which is more of a moving target (changeset 1020). - creating self-referring mappers with backrefs slightly easier (but still not that easy - changeset 1019) @@ -1958,37 +3973,37 @@ changeset 1019) - psycopg1 date/time issue with None fixed (changeset 1005) - two issues related to postgres, which doesnt want to give you the "lastrowid" since oids are deprecated: - * postgres database-side defaults that are on primary key cols *do* execute + * postgres database-side defaults that are on primary key cols *do* execute explicitly beforehand, even though thats not the idea of a PassiveDefault. this is because sequences on columns get reflected as PassiveDefaults, but need to be explicitly -executed on a primary key col so we know what we just inserted. - * if you did add a row that has a bunch of database-side defaults on it, -and the PassiveDefault thing was working the old way, i.e. they just execute on +executed on a primary key col so we know what we just inserted. + * if you did add a row that has a bunch of database-side defaults on it, +and the PassiveDefault thing was working the old way, i.e. they just execute on the DB side, the "cant get the row back without an OID" exception that occurred also will not happen unless someone (usually the ORM) explicitly asks for it. -- fixed a glitch with engine.execute_compiled where it was making a second +- fixed a glitch with engine.execute_compiled where it was making a second ResultProxy that just got thrown away. - began to implement newer logic in object properities. you can now say myclass.attr.property, which will give you the PropertyLoader corresponding to that attribute, i.e. myclass.mapper.props['attr'] -- eager loading has been internally overhauled to use aliases at all times. more +- eager loading has been internally overhauled to use aliases at all times. more complicated chains of eager loads can now be created without any need for explicit "use aliases"-type instructions. EagerLoader code is also much simpler now. -- a new somewhat experimental flag "use_update" added to relations, indicates that -this relationship should be handled by a second UPDATE statement, either after a +- a new somewhat experimental flag "use_update" added to relations, indicates that +this relationship should be handled by a second UPDATE statement, either after a primary INSERT or before a primary DELETE. handles circular row dependencies. -- added exceptions module, all raised exceptions (except for some +- added exceptions module, all raised exceptions (except for some KeyError/AttributeError exceptions) descend from these classes. - fix to date types with MySQL, returned timedelta converted to datetime.time -- two-phase objectstore.commit operations (i.e. begin/commit) now return a +- two-phase objectstore.commit operations (i.e. begin/commit) now return a transactional object (SessionTrans), to more clearly indicate transaction boundaries. 
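
The exceptions module mentioned above can be exercised in a couple of lines;
SQLAlchemyError is assumed here to be the base of the new hierarchy:

    from sqlalchemy import create_engine, exceptions

    engine = create_engine('sqlite://')
    try:
        engine.execute("SELECT * FROM no_such_table")
    except exceptions.SQLAlchemyError, err:
        print "caught:", err
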
- Index object with create/drop support added to schema - fix to postgres, where it will explicitly pre-execute a PassiveDefault on a table -if it is a primary key column, pursuant to the ongoing "we cant get inserted rows +if it is a primary key column, pursuant to the ongoing "we cant get inserted rows back from postgres" issue - change to information_schema query that gets back postgres table defs, now uses explicit JOIN keyword, since one user had faster performance with 8.1 -- fix to engine.process_defaults so it works correctly with a table that has +- fix to engine.process_defaults so it works correctly with a table that has different column name/column keys (changset 982) - a column can only be attached to one table - this is now asserted - postgres time types descend from Time type diff --git a/LICENSE b/LICENSE index 73cedc004e..77490f7468 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ This is the MIT license: http://www.opensource.org/licenses/mit-license.php -Copyright (c) 2005, 2006, 2007 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael +Copyright (c) 2005, 2006, 2007, 2008 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael Bayer. Permission is hereby granted, free of charge, to any person obtaining a copy of this diff --git a/README b/README index 881ef4c709..28f843e991 100644 --- a/README +++ b/README @@ -1,17 +1,49 @@ -SQLAlchemy is licensed under an MIT-style license (see LICENSE). -Other incorporated projects may be licensed under different licenses. -All licenses allow for non-commercial and commercial use. +SQLAlchemy +++++++++++ -To install: +The Python SQL Toolkit and Object Relational Mapper - python setup.py install +Requirements +------------ -SVN checkouts also include setup.cfg file allowing setuptools to create -an svn-tagged build. +SQLAlchemy requires Python 2.3 or higher. One or more DB-API implementations +are also required for database access. See docs/intro.html for more +information on supported DB-API implementations. -Documentation is available in HTML format in the ./doc/ directory. +Installing +---------- -Information running unit tests is in README.unittests. +To install:: -good luck ! - + python setup.py install + +To use without installation, include the ``lib`` directory in your Python +path. + +Package Contents +---------------- + + doc/ + HTML documentation, including tutorials and API reference. + + examples/ + Fully commented and executable implementations for a variety of tasks. + + lib/ + SQLAlchemy. + + test/ + Unit tests for SQLAlchemy. See ``README.unittests`` for more + information. + +Help +---- + +Mailing lists, wiki, and more are available on-line at +http://www.sqlalchemy.org. + +License +------- + +SQLAlchemy is distributed under the `MIT license +`_. diff --git a/README.unittests b/README.unittests index a4a9b5197a..37f4772702 100644 --- a/README.unittests +++ b/README.unittests @@ -1,101 +1,196 @@ +===================== SQLALCHEMY UNIT TESTS ----------------------- +===================== SETUP ----- -To run unit tests (assuming unix-style commandline, adjust as needed for windows): +SQLite support is required. These instructions assume standard Python 2.4 or +higher. See the section on alternate Python implementations for information on +testing with 2.3 and other Pythons. -Python 2.4 or greater is required since the unit tests use decorators. +The 'test' directory must be on the PYTHONPATH. -cd into the SQLAlchemy distribution directory. 
+cd into the SQLAlchemy distribution directory -Set up the PYTHONPATH. In bash: +In bash: - export PYTHONPATH=./test/ + $ export PYTHONPATH=./test/ On windows: - set PYTHONPATH=test\ + C:\sa\> set PYTHONPATH=test\ + + Adjust any other use Unix-style paths in this README as needed. + +The unittest framework will automatically prepend the lib/ directory to +sys.path. This forces the local version of SQLAlchemy to be used, bypassing +any setuptools-installed installations (setuptools places .egg files ahead of +plain directories, even if on PYTHONPATH, unfortunately). -The unittest framework will automatically prepend the lib directory to -sys.path. This forces the local version of SQLAlchemy to be used, -bypassing any setuptools-installed installations (setuptools places -.egg files ahead of plain directories, even if on PYTHONPATH, -unfortunately). RUNNING ALL TESTS ----------------- To run all tests: - python test/alltests.py + $ python test/alltests.py + + +RUNNING INDIVIDUAL TESTS +------------------------- +Any unittest module can be run directly from the module file: + + python test/orm/mapper.py + +To run a specific test within the module, specify it as ClassName.methodname: + + python test/orm/mapper.py MapperTest.testget + COMMAND LINE OPTIONS -------------------- -Help is available via: +Help is available via --help - python test/alltests.py --help + $ python test/alltests.py --help usage: alltests.py [options] [tests...] - - options: + + Options: -h, --help show this help message and exit - --dburi=DBURI database uri (overrides --db) - --db=DB prefab database uri (sqlite, sqlite_file, postgres, - mysql, oracle, oracle8, mssql) - --mockpool use mock pool --verbose enable stdout echoing/printing - --log-info=LOG_INFO turn on info logging for (multiple OK) - --log-debug=LOG_DEBUG - turn on debug logging for (multiple OK) - --quiet suppress unittest output - --nothreadlocal dont use thread-local mod - --enginestrategy=ENGINESTRATEGY - engine strategy (plain or threadlocal, defaults to SA - default) - --coverage Dump a full coverage report after running - -NON-SQLITE DATABASES --------------------- -The prefab database connections expect to log in to localhost on the -default port as user "scott", password "tiger", database "test" (where -applicable). E.g. for postgresql the this translates to -"postgres://scott:tiger@127.0.0.1:5432/test". + --quiet suppress output + [...] -RUNNING INDIVIDUAL TESTS -------------------------- -Any unittest module can be run directly from the module file (same commandline options): +Command line options can applied to alltests.py or any individual test module. +Many are available. The most commonly used are '--db' and '--dburi'. - python test/orm/mapper.py -Additionally, to run a speciic test within the module, specify it as ClassName.methodname: +DATABASE TARGETS +---------------- + +Tests will target an in-memory SQLite database by default. To test against +another database, use the --dburi option with any standard SQLAlchemy URL: + + --dburi=postgres://user:password@localhost/test + +Use an empty database and a database user with general DBA privileges. The +test suite will be creating and dropping many tables and other DDL, and +preexisting tables will interfere with the tests + +If you'll be running the tests frequently, database aliases can save a lot of +typing. 
The --dbs option lists the built-in aliases and their matching URLs: + + $ python test/alltests.py --dbs + Available --db options (use --dburi to override) + mysql mysql://scott:tiger@127.0.0.1:3306/test + oracle oracle://scott:tiger@127.0.0.1:1521 + postgres postgres://scott:tiger@127.0.0.1:5432/test + [...] + +To run tests against an aliased database: + + $ python test/alltests.py --db=postgres + +To customize the URLs with your own users or hostnames, make a simple .ini +file called `test.cfg` at the top level of the SQLAlchemy source distribution +or a `.satest.cfg` in your home directory: + + [db] + postgres=postgres://myuser:mypass@localhost/mydb + +Your custom entries will override the defaults and you'll see them reflected +in the output of --dbs. - python test/orm/mapper.py MapperTest.testget CONFIGURING LOGGING ---------------------- -Logging is now available via Python's logging package. Any area of SQLAlchemy can be logged -through the unittest interface, such as: +------------------- +SQLAlchemy logs its activity and debugging through Python's logging package. +Any log target can be directed to the console with command line options, such +as: + + $ python test/orm/unitofwork.py --log-info=sqlalchemy.orm.mapper \ + --log-debug=sqlalchemy.pool --log-info=sqlalchemy.engine -Log mapper configuration, connection pool checkouts, and SQL statement execution: +This would log mapper configuration, connection pool checkouts, and SQL +statement execution. - python test/orm/unitofwork.py --log-info=sqlalchemy.orm.mapper --log-debug=sqlalchemy.pool --log-info=sqlalchemy.engine BUILT-IN COVERAGE REPORTING ------------------------------ -Coverage is now integrated through the coverage.py module, included in the './test/' directory. Running the test suite with -the --coverage switch will generate a local file ".coverage" containing coverage details, and a report will be printed -to standard output with an overview of the coverage gathered from the last unittest run (the file is deleted between runs). +Coverage is tracked with coverage.py module, included in the './test/' +directory. Running the test suite with the --coverage switch will generate a +local file ".coverage" containing coverage details, and a report will be +printed to standard output with an overview of the coverage gathered from the +last unittest run (the file is deleted between runs). + +After the suite has been run with --coverage, an annotated version of any +source file can be generated, marking statements that are executed with > and +statements that are missed with !, by running the coverage.py utility with the +"-a" (annotate) option, such as: + + $ python ./test/testlib/coverage.py -a ./lib/sqlalchemy/sql.py + +This will create a new annotated file ./lib/sqlalchemy/sql.py,cover. Pretty +cool! + + +TESTING NEW DIALECTS +-------------------- +You can use the SQLAlchemy test suite to test any new database dialect in +development. All possible database features will be exercised by default. +Test decorators are provided that can exclude unsupported tests for a +particular dialect. You'll see them all over the source, feel free to add +your dialect to them or apply new decorations to existing tests as required. + +It's fine to start out with very broad exclusions, e.g. "2-phase commit is not +supported on this database" and later refine that as needed "2-phase commit is +not available until server version 8". 
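+
+In practice, a good first pass is to point individual test modules at the new
+dialect with --dburi and add exclusions as failures turn up; for example (the
+URL and module paths below are placeholders -- substitute your own dialect and
+checkout layout):
+
+    $ python test/sql/query.py --dburi=mydialect://scott:tiger@localhost/test
+    $ python test/engine/reflection.py --dburi=mydialect://scott:tiger@localhost/test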
+ +To be considered for inclusion in the SQLAlchemy distribution, a dialect must +be integrated with the standard test suite. Dialect-specific tests can be +placed in the 'dialects/' directory. Comprehensive testing of +database-specific column types and their proper reflection are a very good +place to start. + +When working through the tests, start with 'engine' and 'sql' tests. 'engine' +performs a wide range of transaction tests that might deadlock on a brand-new +dialect- try disabling those if you're having problems and revisit them later. + +Once the 'sql' tests are passing, the 'orm' tests should pass as well, modulo +any adjustments needed for SQL features the ORM uses that might not be +available in your database. But if an 'orm' test requires changes to your +dialect or the SQLAlchemy core to pass, there's a test missing in 'sql'! Any +time you can spend boiling down the problem to it's essential sql roots and +adding a 'sql' test will be much appreciated. + +The test suite is very effective at illuminating bugs and inconsistencies in +an underlying DB-API (or database!) implementation. Workarounds are almost +always possible. If you hit a wall, join us on the mailing list or, better, +IRC! -After the suite has been run with --coverage, an annotated version of any source file can be generated -marking statements that are executed with > and statements that are missed with !, by running the coverage.py -utility with the "-a" (annotate) option, such as: - python ./test/coverage.py -a ./lib/sqlalchemy/sql.py +ALTERNATE PYTHON IMPLEMENTATIONS +-------------------------------- +The test suite restricts itself to largely Python 2.3-level constructs and +standard library features, with the notable exception of decorators, which are +used extensively throughout the suite. + +A source transformation tool is included that allows testing on Python 2.3 or +any other Python implementation that lacks @decorator support. + +To use it: + + $ python test/clone.py -c --filter=py23 test23 + +This will copy the test/ directory structure into test23/, with @decorators in +the source code transformed into 2.3-friendly syntax. -which will create a new annotated file ./lib/sqlalchemy/sql.py,cover . Pretty cool ! TIPS ---- -When running the tests on postgres, postgres gets slower and slower each time you run the tests. -This seems to be related to the constant creation/dropping of tables. Running a "VACUUM FULL" -on the database will speed it up again. +Postgres: The tests require an 'alt_schema' and 'alt_schema2' to be present in +the testing database. + +Postgres: When running the tests on postgres, postgres can get slower and +slower each time you run the tests. This seems to be related to the constant +creation/dropping of tables. Running a "VACUUM FULL" on the database will +speed it up again. diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000..ef52a64807 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.4.6 diff --git a/doc/build/content/adv_datamapping.txt b/doc/build/content/adv_datamapping.txt deleted file mode 100644 index 033fb33594..0000000000 --- a/doc/build/content/adv_datamapping.txt +++ /dev/null @@ -1,1043 +0,0 @@ -[alpha_api]: javascript:alphaApi() -[alpha_implementation]: javascript:alphaImplementation() - -Advanced Data Mapping {@name=advdatamapping} -====================== - -This section details all the options available to Mappers, as well as advanced patterns. 
- -To start, heres the tables we will work with again: - - {python} - from sqlalchemy import * - - metadata = MetaData() - - # a table to store users - users_table = Table('users', metadata, - Column('user_id', Integer, primary_key = True), - Column('user_name', String(40)), - Column('password', String(80)) - ) - - # a table that stores mailing addresses associated with a specific user - addresses_table = Table('addresses', metadata, - Column('address_id', Integer, primary_key = True), - Column('user_id', Integer, ForeignKey("users.user_id")), - Column('street', String(100)), - Column('city', String(80)), - Column('state', String(2)), - Column('zip', String(10)) - ) - - # a table that stores keywords - keywords_table = Table('keywords', metadata, - Column('keyword_id', Integer, primary_key = True), - Column('name', VARCHAR(50)) - ) - - # a table that associates keywords with users - userkeywords_table = Table('userkeywords', metadata, - Column('user_id', INT, ForeignKey("users")), - Column('keyword_id', INT, ForeignKey("keywords")) - ) - - -### More On Mapper Properties {@name=properties} - -#### Overriding Column Names {@name=colname} - -When mappers are constructed, by default the column names in the Table metadata are used as the names of attributes on the mapped class. This can be customzed within the properties by stating the key/column combinations explicitly: - - {python} - user_mapper = mapper(User, users_table, properties={ - 'id' : users_table.c.user_id, - 'name' : users_table.c.user_name, - }) - -In the situation when column names overlap in a mapper against multiple tables, columns may be referenced together with a list: - - {python} - # join users and addresses - usersaddresses = sql.join(users_table, addresses_table, users_table.c.user_id == addresses_table.c.user_id) - m = mapper(User, usersaddresses, - properties = { - 'id' : [users_table.c.user_id, addresses_table.c.user_id], - } - ) - -#### Overriding Properties {@name=overriding} - -A common request is the ability to create custom class properties that override the behavior of setting/getting an attribute. Currently, the easiest way to do this in SQLAlchemy is how it would be done in any Python program; define your attribute with a different name, such as "_attribute", and use a property to get/set its value. The mapper just needs to be told of the special name: - - {python} - class MyClass(object): - def _set_email(self, email): - self._email = email - def _get_email(self): - return self._email - email = property(_get_email, _set_email) - - mapper(MyClass, mytable, properties = { - # map the '_email' attribute to the "email" column - # on the table - '_email': mytable.c.email - }) - -It is also possible to route the the `select_by` and `get_by` functions on `Query` using the new property name, by establishing a `synonym`: - - {python} - mapper(MyClass, mytable, properties = { - # map the '_email' attribute to the "email" column - # on the table - '_email': mytable.c.email, - - # make a synonym 'email' - 'email' : synonym('_email') - }) - - # now you can select_by(email) - result = session.query(MyClass).select_by(email='john@smith.com') - -Synonym can be established with the flag "proxy=True", to create a class-level proxy to the actual property. 
This has the effect of creating a fully functional synonym on class instances:
-
-    {python}
-    mapper(MyClass, mytable, properties = {
-        '_email': mytable.c.email,
-        'email' : synonym('_email', proxy=True)
-    })
-
-    x = MyClass()
-    x.email = 'john@doe.com'
-
-    >>> x._email
-    'john@doe.com'
-
-#### Entity Collections {@name=entitycollections}
-
-Mapping a one-to-many or many-to-many relationship results in a collection of values accessible through an attribute on the parent instance. By default, this collection is a `list`:
-
-    {python}
-    mapper(Parent, properties={
-        'children' : relation(Child)
-    })
-
-    parent = Parent()
-    parent.children.append(Child())
-    print parent.children[0]
-
-Collections are not limited to lists. Sets, mutable sequences and almost any other Python object that can act as a container can be used in place of the default list.
-
-    {python}
-    # use a set
-    mapper(Parent, properties={
-        'children' : relation(Child, collection_class=set)
-    })
-
-    parent = Parent()
-    child = Child()
-    parent.children.add(child)
-    assert child in parent.children
-
-##### Custom Entity Collections {@name=customcollections}
-
-You can use your own types for collections as well. For most cases, simply inherit from `list` or `set` and add the custom behavior.
-
-Collections in SQLAlchemy are transparently *instrumented*. Instrumentation means that normal operations on the collection are tracked and result in changes being written to the database at flush time. Additionally, collection operations can fire *events* which indicate some secondary operation must take place. Examples of a secondary operation include saving the child item in the parent's `Session` (i.e. the `save-update` cascade), as well as synchronizing the state of a bi-directional relationship (i.e. a `backref`).
-
-The collections package understands the basic interface of lists, sets and dicts and will automatically apply instrumentation to those built-in types and their subclasses. Object-derived types that implement a basic collection interface are detected and instrumented via duck-typing:
-
-    {python}
-    class ListLike(object):
-        def __init__(self):
-            self.data = []
-        def append(self, item):
-            self.data.append(item)
-        def remove(self, item):
-            self.data.remove(item)
-        def extend(self, items):
-            self.data.extend(items)
-        def __iter__(self):
-            return iter(self.data)
-        def foo(self):
-            return 'foo'
-
-`append`, `remove`, and `extend` are known list-like methods, and will be instrumented automatically. `__iter__` is not a mutator method and won't be instrumented, and `foo` won't be either.
-
-Duck-typing (i.e. guesswork) isn't rock-solid, of course, so you can be explicit about the interface you are implementing by providing an `__emulates__` class attribute:
-
-    {python}
-    class SetLike(object):
-        __emulates__ = set
-
-        def __init__(self):
-            self.data = set()
-        def append(self, item):
-            self.data.add(item)
-        def remove(self, item):
-            self.data.remove(item)
-        def __iter__(self):
-            return iter(self.data)
-
-This class looks list-like because of `append`, but `__emulates__` forces it to set-like. `remove` is known to be part of the set interface and will be instrumented.
-
-But this class won't work quite yet: a little glue is needed to adapt it for use by SQLAlchemy. The ORM needs to know which methods to use to append, remove and iterate over members of the collection. When using a type like `list` or `set`, the appropriate methods are well-known and used automatically when present.
This set-like class does not provide the expected `add` method, so we must supply an explicit mapping for the ORM via a decorator.
-
-##### Collection Decorators {@name=collectiondecorators}
-
-Decorators can be used to tag the individual methods the ORM needs to manage collections. Use them when your class doesn't quite meet the regular interface for its container type, or you simply would like to use a different method to get the job done.
-
-    {python}
-    from sqlalchemy.orm.collections import collection
-
-    class SetLike(object):
-        __emulates__ = set
-
-        def __init__(self):
-            self.data = set()
-
-        @collection.appender
-        def append(self, item):
-            self.data.add(item)
-
-        def remove(self, item):
-            self.data.remove(item)
-
-        def __iter__(self):
-            return iter(self.data)
-
-And that's all that's needed to complete the example. SQLAlchemy will add instances via the `append` method. `remove` and `__iter__` are the default methods for sets and will be used for removing and iteration. Default methods can be changed as well:
-
-    {python}
-    from sqlalchemy.orm.collections import collection
-
-    class MyList(list):
-        @collection.remover
-        def zark(self, item):
-            # do something special...
-
-        @collection.iterator
-        def hey_use_this_instead_for_iteration(self):
-            # ...
-
-There is no requirement to be list- or set-like at all. Collection classes can be any shape, so long as they have the append, remove and iterate interface marked for SQLAlchemy's use. Append and remove methods will be called with a mapped entity as the single argument, and iterator methods are called with no arguments and must return an iterator.
-
-##### Dictionary-Based Collections {@name=dictcollections}
-
-A `dict` can be used as a collection, but a keying strategy is needed to map entities loaded by the ORM to key/value pairs. The [collections](rel:docstrings_sqlalchemy.orm.collections) package provides several built-in types for dictionary-based collections:
-
-    {python}
-    from sqlalchemy.orm.collections import column_mapped_collection, attr_mapped_collection, mapped_collection
-
-    mapper(Item, items_table, properties={
-        # key by column
-        'notes' : relation(Note, collection_class=column_mapped_collection(notes_table.c.keyword)),
-        # or named attribute
-        'notes2' : relation(Note, collection_class=attr_mapped_collection('keyword')),
-        # or any callable
-        'notes3' : relation(Note, collection_class=mapped_collection(lambda entity: entity.a + entity.b))
-    })
-
-    # ...
-    item = Item()
-    item.notes['color'] = Note('color', 'blue')
-    print item.notes['color']
-
-These functions each provide a `dict` subclass with decorated `set` and `remove` methods and the keying strategy of your choice.
-
-The [collections.MappedCollection](rel:docstrings_sqlalchemy.orm.collections.MappedCollection) class can be used as a base class for your custom types or as a mix-in to quickly add `dict` collection support to other classes.
It uses a keying function to delegate to `__setitem__` and `__delitem__`: - - {python} - from sqlalchemy.util import OrderedDict - from sqlalchemy.orm.collections import MappedCollection - - class NodeMap(OrderedDict, MappedCollection): - """Holds 'Node' objects, keyed by the 'name' attribute with insert order maintained.""" - - def __init__(self, *args, **kw): - MappedCollection.__init__(self, keyfunc=lambda node: node.name) - OrderedDict.__init__(self, *args, **kw) - -The ORM understands the `dict` interface just like lists and sets, and will automatically instrument all dict-like methods if you choose to subclass `dict` or provide dict-like collection behavior in a duck-typed class. You must decorate appender and remover methods, however- there are no compatible methods in the basic dictionary interface for SQLAlchemy to use by default. Iteration will go through `itervalues()` unless otherwise decorated. - -##### Instrumentation and Custom Types {@name=adv_collections} - -Many custom types and existing library classes can be used as a entity collection type as-is without further ado. However, it is important to note that the instrumentation process _will_ modify the type, adding decorators around methods automatically. - -The decorations are lightweight and no-op outside of relations, but they do add unneeded overhead when triggered elsewhere. When using a library class as a collection, it can be good practice to use the "trivial subclass" trick to restrict the decorations to just your usage in relations. For example: - - {python} - class MyAwesomeList(some.great.library.AwesomeList): - pass - - # ... relation(..., collection_class=MyAwesomeList) - -The ORM uses this approach for built-ins, quietly substituting a trivial subclass when a `list`, `set` or `dict` is used directly. - -The collections package provides additional decorators and support for authoring custom types. See the [package documentation](rel:docstrings_sqlalchemy.orm.collections) for more information and discussion of advanced usage and Python 2.3-compatible decoration options. - -#### Custom Join Conditions {@name=customjoin} - -When creating relations on a mapper, most examples so far have illustrated the mapper and relationship joining up based on the foreign keys of the tables they represent. 
In fact, this "automatic" inspection can be completely circumvented using the `primaryjoin` and `secondaryjoin` arguments to `relation`, as in this example, which creates a User object with a relationship to all of its Addresses that are in Boston:
-
-    {python}
-    class User(object):
-        pass
-    class Address(object):
-        pass
-
-    mapper(Address, addresses_table)
-    mapper(User, users_table, properties={
-        'boston_addresses' : relation(Address, primaryjoin=
-                    and_(users_table.c.user_id==Address.c.user_id,
-                    Address.c.city=='Boston'))
-    })
-
-Many-to-many relationships can be customized by one or both of `primaryjoin` and `secondaryjoin`, shown below with just the default many-to-many relationship explicitly set:
-
-    {python}
-    class User(object):
-        pass
-    class Keyword(object):
-        pass
-    mapper(Keyword, keywords_table)
-    mapper(User, users_table, properties={
-        'keywords':relation(Keyword, secondary=userkeywords_table,
-            primaryjoin=users_table.c.user_id==userkeywords_table.c.user_id,
-            secondaryjoin=userkeywords_table.c.keyword_id==keywords_table.c.keyword_id
-            )
-    })
-
-#### Lazy/Eager Joins Multiple Times to One Table {@name=multiplejoin}
-
-The previous example leads into the idea of joining against the same table multiple times. Below is a User object that has lists of its Boston and New York addresses:
-
-    {python}
-    mapper(User, users_table, properties={
-        'boston_addresses' : relation(Address, primaryjoin=
-                    and_(users_table.c.user_id==Address.c.user_id,
-                    Address.c.city=='Boston')),
-        'newyork_addresses' : relation(Address, primaryjoin=
-                    and_(users_table.c.user_id==Address.c.user_id,
-                    Address.c.city=='New York')),
-    })
-
-Both lazy and eager loading support multiple joins equally well.
-
-#### Deferred Column Loading {@name=deferred}
-
-This feature allows particular columns of a table to not be loaded by default, instead being loaded later on when first referenced. It is essentially "column-level lazy loading". This feature is useful when one wants to avoid loading a large text or binary field into memory when it's not needed. Individual columns can be lazy loaded by themselves or placed into groups that lazy-load together.
-
-    {python}
-    book_excerpts = Table('books', db,
-        Column('book_id', Integer, primary_key=True),
-        Column('title', String(200), nullable=False),
-        Column('summary', String(2000)),
-        Column('excerpt', String),
-        Column('photo', Binary)
-    )
-
-    class Book(object):
-        pass
-
-    # define a mapper that will load each of 'excerpt' and 'photo' in
-    # separate, individual-row SELECT statements when each attribute
-    # is first referenced on the individual object instance
-    mapper(Book, book_excerpts, properties = {
-        'excerpt' : deferred(book_excerpts.c.excerpt),
-        'photo' : deferred(book_excerpts.c.photo)
-    })
-
-Deferred columns can be placed into groups so that they load together:
-
-    {python}
-    book_excerpts = Table('books', db,
-        Column('book_id', Integer, primary_key=True),
-        Column('title', String(200), nullable=False),
-        Column('summary', String(2000)),
-        Column('excerpt', String),
-        Column('photo1', Binary),
-        Column('photo2', Binary),
-        Column('photo3', Binary)
-    )
-
-    class Book(object):
-        pass
-
-    # define a mapper with a 'photos' deferred group. when one photo is referenced,
-    # all three photos will be loaded in one SELECT statement. The 'excerpt' will
-    # be loaded separately when it is first referenced.
-    mapper(Book, book_excerpts, properties = {
-        'excerpt' : deferred(book_excerpts.c.excerpt),
-        'photo1' : deferred(book_excerpts.c.photo1, group='photos'),
-        'photo2' : deferred(book_excerpts.c.photo2, group='photos'),
-        'photo3' : deferred(book_excerpts.c.photo3, group='photos')
-    })
-
-You can defer or undefer columns at the `Query` level with the `options` method:
-
-    {python}
-    query = session.query(Book)
-    query.options(defer('summary')).all()
-    query.options(undefer('excerpt')).all()
-
-#### Working with Large Collections
-
-SQLAlchemy relations are generally simplistic; the lazy loader loads in the full list of child objects when accessed, and the eager load builds a query that loads the full list of child objects. Additionally, when you are deleting a parent object, SQLAlchemy ensures that it has loaded the full list of child objects so that it can mark them as deleted as well (or to update their parent foreign key to NULL). It does not issue an en-masse "delete from table where parent_id=?" type of statement in such a scenario. This is because the child objects themselves may also have further dependencies, and additionally may also exist in the current session, in which case SA needs to know their identity so that their state can be properly updated.
-
-So there are several techniques that can be used individually or combined together to address these issues, in the context of a large collection where you normally would not want to load the full list of relationships:
-
-* Use `lazy=None` to disable child object loading (i.e. noload)
-
-    {python}
-    mapper(MyClass, table, properties={
-        'children':relation(MyOtherClass, lazy=None)
-    })
-
-* To load child objects, just use a query. Of particular convenience is that `Query` is a generative object, so you can return it as is, allowing additional criteria to be added as needed:
-
-    {python}
-    class Organization(object):
-        def __init__(self, name):
-            self.name = name
-        member_query = property(lambda self: object_session(self).query(Member).with_parent(self))
-
-    myorg = sess.query(Organization).get(5)
-
-    # get all members
-    members = myorg.member_query.list()
-
-    # query a subset of members using LIMIT/OFFSET
-    members = myorg.member_query[5:10]
-
-* Use `passive_deletes=True` to disable child object loading on a DELETE operation, in conjunction with "ON DELETE (CASCADE|SET NULL)" on your database to automatically cascade deletes to child objects.
Note that "ON DELETE" is not supported on SQLite, and requires `InnoDB` tables when using MySQL:
-
-    {python}
-    mytable = Table('mytable', meta,
-        Column('id', Integer, primary_key=True),
-        )
-
-    myothertable = Table('myothertable', meta,
-        Column('id', Integer, primary_key=True),
-        Column('parent_id', Integer),
-        ForeignKeyConstraint(['parent_id'], ['mytable.id'], ondelete="CASCADE"),
-        )
-
-    mapper(MyOtherClass, myothertable)
-
-    mapper(MyClass, mytable, properties={
-        'children':relation(MyOtherClass, passive_deletes=True)
-    })
-
-* As an alternative to using "ON DELETE CASCADE", for very simple scenarios you can create a `MapperExtension` that will issue a DELETE for child objects before the parent object is deleted:
-
-    {python}
-    class DeleteMemberExt(MapperExtension):
-        def before_delete(self, mapper, connection, instance):
-            connection.execute(member_table.delete(member_table.c.org_id==instance.org_id))
-
-    mapper(Organization, org_table, extension=DeleteMemberExt(), properties = {
-        'members' : relation(Member, lazy=None, passive_deletes=True, cascade="all, delete-orphan")
-    })
-
-Note that this approach is not nearly as efficient or general-purpose as "ON DELETE CASCADE", since the database itself can cascade the operation along any number of tables.
-
-The latest distribution includes an example `examples/collection/large_collection.py` which illustrates most of these techniques.
-
-#### Relation Options {@name=relationoptions}
-
-Options which can be sent to the `relation()` function. For arguments to `mapper()`, see [advdatamapping_mapperoptions](rel:advdatamapping_mapperoptions).
-
-* **association** - Deprecated; as of version 0.3.0 the association keyword is synonymous with applying the "all, delete-orphan" cascade to a "one-to-many" relationship. SA can now automatically reconcile a "delete" and "insert" operation of two objects with the same "identity" in a flush() operation into a single "update" statement, which is the pattern that "association" used to indicate. See the updated example of association mappings in [datamapping_association](rel:datamapping_association).
-* **backref** - indicates the name of a property to be placed on the related mapper's class that will handle this relationship in the other direction, including synchronizing the object attributes on both sides of the relation. Can also point to a `backref()` construct for more configurability. See [datamapping_relations_backreferences](rel:datamapping_relations_backreferences).
-* **cascade** - a string list of cascade rules which determines how persistence operations should be "cascaded" from parent to child. For a description of cascade rules, see [datamapping_relations_lifecycle](rel:datamapping_relations_lifecycle) and [unitofwork_cascade](rel:unitofwork_cascade).
-* **collection_class** - a class or function that returns a new list-holding object, which will be used in place of a plain list for storing elements. See [advdatamapping_properties_customlist](rel:advdatamapping_properties_customlist).
-* **foreign_keys** - a list of columns which are to be used as "foreign key" columns. This parameter should be used in conjunction with explicit `primaryjoin` and `secondaryjoin` (if needed) arguments, and the columns within the `foreign_keys` list should be present within those join conditions. Normally, `relation()` will inspect the columns within the join conditions to determine which columns are the "foreign key" columns, based on information in the `Table` metadata.
Use this argument when no ForeignKey's are present in the join condition, or to override the table-defined foreign keys. -* **foreignkey** - deprecated. use the `foreign_keys` argument for foreign key specification, or `remote_side` for "directional" logic. -* **lazy=True** - specifies how the related items should be loaded. a value of True indicates they should be loaded lazily when the property is first accessed. A value of False indicates they should be loaded by joining against the parent object query, so parent and child are loaded in one round trip (i.e. eagerly). A value of None indicates the related items are not loaded by the mapper in any case; the application will manually insert items into the list in some other way. In all cases, items added or removed to the parent object's collection (or scalar attribute) will cause the appropriate updates and deletes upon flush(), i.e. this option only affects load operations, not save operations. -* **order_by** - indicates the ordering that should be applied when loading these items. See the section [advdatamapping_orderby](rel:advdatamapping_orderby) for details. -* **passive_deletes=False** - Indicates if lazy-loaders should not be executed during the `flush()` process, which normally occurs in order to locate all existing child items when a parent item is to be deleted. Setting this flag to True is appropriate when `ON DELETE CASCADE` rules have been set up on the actual tables so that the database may handle cascading deletes automatically. This strategy is useful particularly for handling the deletion of objects that have very large (and/or deep) child-object collections. See the example in [advdatamapping_properties_working](rel:advdatamapping_properties_working). -* **post_update** - this indicates that the relationship should be handled by a second UPDATE statement after an INSERT or before a DELETE. Currently, it also will issue an UPDATE after the instance was UPDATEd as well, although this technically should be improved. This flag is used to handle saving bi-directional dependencies between two individual rows (i.e. each row references the other), where it would otherwise be impossible to INSERT or DELETE both rows fully since one row exists before the other. Use this flag when a particular mapping arrangement will incur two rows that are dependent on each other, such as a table that has a one-to-many relationship to a set of child rows, and also has a column that references a single child row within that list (i.e. both tables contain a foreign key to each other). If a `flush()` operation returns an error that a "cyclical dependency" was detected, this is a cue that you might want to use `post_update` to "break" the cycle. -* **primaryjoin** - a ClauseElement that will be used as the primary join of this child object against the parent object, or in a many-to-many relationship the join of the primary object to the association table. By default, this value is computed based on the foreign key relationships of the parent and child tables (or association table). -* **private=False** - deprecated. setting `private=True` is the equivalent of setting `cascade="all, delete-orphan"`, and indicates the lifecycle of child objects should be contained within that of the parent. See the example in [datamapping_relations_lifecycle](rel:datamapping_relations_lifecycle). -* **remote_side** - used for self-referential relationships, indicates the column or list of columns that form the "remote side" of the relationship. 
See the examples in [advdatamapping_selfreferential](rel:advdatamapping_selfreferential). -* **secondary** - for a many-to-many relationship, specifies the intermediary table. The `secondary` keyword argument should generally only be used for a table that is not otherwise expressed in any class mapping. In particular, using the [Association Object Pattern](rel:datamapping_association) is generally mutually exclusive against using the `secondary` keyword argument. -* **secondaryjoin** - a ClauseElement that will be used as the join of an association table to the child object. By default, this value is computed based on the foreign key relationships of the association and child tables. -* **uselist=(True|False)** - a boolean that indicates if this property should be loaded as a list or a scalar. In most cases, this value is determined automatically by `relation()`, based on the type and direction of the relationship - one to many forms a list, many to one forms a scalar, many to many is a list. If a scalar is desired where normally a list would be present, such as a bi-directional one-to-one relationship, set uselist to False. -* **viewonly=False** - when set to True, the relation is used only for loading objects within the relationship, and has no effect on the unit-of-work flush process. Relations with viewonly can specify any kind of join conditions to provide additional views of related objects onto a parent object. Note that the functionality of a viewonly relationship has its limits - complicated join conditions may not compile into eager or lazy loaders properly. If this is the case, use an alternative method, such as those described in [advdatamapping_properties_working](rel:advdatamapping_properties_working), [advdatamapping_resultset](rel:advdatamapping_resultset), or [advdatamapping_selects](rel:advdatamapping_selects). - -### Controlling Ordering {@name=orderby} - -By default, mappers will attempt to ORDER BY the "oid" column of a table, or the primary key column, when selecting rows. This can be modified in several ways. - -The "order_by" parameter can be sent to a mapper, overriding the per-engine ordering if any. A value of None means that the mapper should not use any ordering. 
A non-None value, which can be a column, an `asc` or `desc` clause, or a list of either one, indicates the ORDER BY clause that should be added to all select queries:
-
-    {python}
-    # disable all ordering
-    m = mapper(User, users_table, order_by=None)
-
-    # order by a column
-    m = mapper(User, users_table, order_by=users_table.c.user_id)
-
-    # order by multiple items
-    m = mapper(User, users_table, order_by=[users_table.c.user_id, desc(users_table.c.user_name)])
-
-"order_by" can also be specified with queries, overriding all other per-engine/per-mapper orderings:
-
-    {python}
-    # order by a column
-    l = query.filter(users_table.c.user_name=='fred').order_by(users_table.c.user_id).all()
-
-    # order by multiple criteria
-    l = query.filter(users_table.c.user_name=='fred').order_by([users_table.c.user_id, desc(users_table.c.user_name)])
-
-The "order_by" property can also be specified on a `relation()` which will control the ordering of the collection:
-
-    {python}
-    mapper(Address, addresses_table)
-
-    # order address objects by address id
-    mapper(User, users_table, properties = {
-        'addresses' : relation(Address, order_by=addresses_table.c.address_id)
-    })
-
-
-### Limiting Rows Combined with Eager Loads {@name=limits}
-
-As indicated in the docs on `Query`, you can limit rows using `limit()` and `offset()`. However, things get tricky when dealing with eager relationships, since a straight LIMIT of rows will interfere with the eagerly-loaded rows. So here is what SQLAlchemy will do when you use limit or offset with an eager relationship:
-
-    {python}
-    class User(object):
-        pass
-    class Address(object):
-        pass
-    mapper(User, users_table, properties={
-        'addresses' : relation(mapper(Address, addresses_table), lazy=False)
-    })
-    r = session.query(User).filter(User.c.user_name.like('F%')).limit(20).offset(10).all()
-    {opensql}SELECT users.user_id AS users_user_id, users.user_name AS users_user_name,
-    users.password AS users_password, addresses.address_id AS addresses_address_id,
-    addresses.user_id AS addresses_user_id, addresses.street AS addresses_street,
-    addresses.city AS addresses_city, addresses.state AS addresses_state,
-    addresses.zip AS addresses_zip
-    FROM
-    (SELECT users.user_id FROM users WHERE users.user_name LIKE %(users_user_name)s
-    ORDER BY users.oid LIMIT 20 OFFSET 10) AS rowcount,
-    users LEFT OUTER JOIN addresses ON users.user_id = addresses.user_id
-    WHERE rowcount.user_id = users.user_id ORDER BY users.oid, addresses.oid
-    {'users_user_name': 'F%'}
-
-The main WHERE clause as well as the limiting clauses are coerced into a subquery; this subquery represents the desired result of objects. A containing query, which handles the eager relationships, is joined against the subquery to produce the result. This is something to keep in mind, as it's a complex query which may be problematic on databases with poor support for LIMIT, such as Oracle, which does not support it natively.
-
-### Mapping a Class with Table Inheritance {@name=inheritance}
-
-Inheritance in databases comes in three forms: *single table inheritance*, where several types of classes are stored in one table, *concrete table inheritance*, where each type of class is stored in its own table, and *joined table inheritance*, where the parent/child classes are stored in their own tables that are joined together in a select.
-
-There is also the ability to load "polymorphically", meaning that a single query loads objects of multiple types at once.
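-
-For example, assuming polymorphic mappings along the lines of the `Employee`/`Manager`/`Engineer` configurations shown below, a single query against the base class returns an instance of the appropriate subclass for each row:
-
-    {python}
-    session = create_session()
-
-    # rows whose discriminator identifies a manager or an engineer come
-    # back as Manager and Engineer instances, respectively
-    for employee in session.query(Employee).all():
-        print employee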
- -SQLAlchemy supports all three kinds of inheritance. Additionally, true "polymorphic" loading is supported in a straightfoward way for single table inheritance, and has some more manually-configured features that can make it happen for concrete and multiple table inheritance. - -Working examples of polymorphic inheritance come with the distribution in the directory `examples/polymorphic`. - -Here are the classes we will use to represent an inheritance relationship: - - {python} - class Employee(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return self.__class__.__name__ + " " + self.name - - class Manager(Employee): - def __init__(self, name, manager_data): - self.name = name - self.manager_data = manager_data - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.manager_data - - class Engineer(Employee): - def __init__(self, name, engineer_info): - self.name = name - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.engineer_info - -Each class supports a common `name` attribute, while the `Manager` class has its own attribute `manager_data` and the `Engineer` class has its own attribute `engineer_info`. - -#### Single Table Inheritance - -This will support polymorphic loading via the `Employee` mapper. - - {python} - employees_table = Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('manager_data', String(50)), - Column('engineer_info', String(50)), - Column('type', String(20)) - ) - - employee_mapper = mapper(Employee, employees_table, polymorphic_on=employees_table.c.type) - manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer') - -#### Concrete Table Inheritance - -Without polymorphic loading, you just define a separate mapper for each class. - - {python title="Concrete Inheritance, Non-polymorphic"} - managers_table = Table('managers', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('manager_data', String(50)), - ) - - engineers_table = Table('engineers', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('engineer_info', String(50)), - ) - - manager_mapper = mapper(Manager, managers_table) - engineer_mapper = mapper(Engineer, engineers_table) - -With polymorphic loading, the SQL query to do the actual polymorphic load must be constructed, usually as a UNION. There is a helper function to create these UNIONS called `polymorphic_union`. 
- - {python title="Concrete Inheritance, Polymorphic"} - pjoin = polymorphic_union({ - 'manager':managers_table, - 'engineer':engineers_table - }, 'type', 'pjoin') - - employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer') - -#### Joined Table Inheritance - -Like concrete table inheritance, this can be done non-polymorphically, or with a little more complexity, polymorphically: - - {python title="Multiple Table Inheritance, Non-polymorphic"} - employees = Table('employees', metadata, - Column('person_id', Integer, primary_key=True), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('employees.person_id'), primary_key=True), - Column('engineer_info', String(50)), - ) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('employees.person_id'), primary_key=True), - Column('manager_data', String(50)), - ) - - person_mapper = mapper(Employee, employees) - mapper(Engineer, engineers, inherits=person_mapper) - mapper(Manager, managers, inherits=person_mapper) - -Polymorphically, joined-table inheritance is easier than concrete, as a simple outer join can usually work: - - {python title="Joined Table Inheritance, Polymorphic"} - person_join = people.outerjoin(engineers).outerjoin(managers) - - person_mapper = mapper(Person, people, select_table=person_join,polymorphic_on=people.c.type, polymorphic_identity='person') - mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - -In SQLAlchemy 0.4, the above mapper setup can load polymorphically *without* the join as well, by issuing distinct queries for each subclasses' table. - -The join condition in a joined table inheritance structure can be specified explicitly, using `inherit_condition`: - - {python} - AddressUser.mapper = mapper( - AddressUser, - addresses_table, inherits=User.mapper, - inherit_condition=users_table.c.user_id==addresses_table.c.user_id - ) - -### Mapping a Class against Multiple Tables {@name=joins} - -Mappers can be constructed against arbitrary relational units (called `Selectables`) as well as plain `Tables`. For example, The `join` keyword from the SQL package creates a neat selectable unit comprised of multiple tables, complete with its own composite primary key, which can be passed in to a mapper as the table. 
- - {python} - # a class - class AddressUser(object): - pass - - # define a Join - j = join(users_table, addresses_table) - - # map to it - the identity of an AddressUser object will be - # based on (user_id, address_id) since those are the primary keys involved - m = mapper(AddressUser, j, properties={ - 'user_id':[users_table.c.user_id, addresses_table.c.user_id] - }) - -A second example: - - {python} - # many-to-many join on an association table - j = join(users_table, userkeywords, - users_table.c.user_id==userkeywords.c.user_id).join(keywords, - userkeywords.c.keyword_id==keywords.c.keyword_id) - - # a class - class KeywordUser(object): - pass - - # map to it - the identity of a KeywordUser object will be - # (user_id, keyword_id) since those are the primary keys involved - m = mapper(KeywordUser, j, properties={ - 'user_id':[users_table.c.user_id, userkeywords.c.user_id], - 'keyword_id':[userkeywords.c.keyword_id, keywords.c.keyword_id] - }) - -In both examples above, "composite" columns were added as properties to the mappers; these are aggregations of multiple columns into one mapper property, which instructs the mapper to keep both of those columns set at the same value. - -### Mapping a Class against Arbitrary Selects {@name=selects} - -Similar to mapping against a join, a plain select() object can be used with a mapper as well. Below, an example select which contains two aggregate functions and a group_by is mapped to a class: - - {python} - s = select([customers, - func.count(orders).label('order_count'), - func.max(orders.price).label('highest_order')], - customers.c.customer_id==orders.c.customer_id, - group_by=[c for c in customers.c] - ).alias('somealias') - class Customer(object): - pass - - m = mapper(Customer, s) - -Above, the "customers" table is joined against the "orders" table to produce a full row for each customer row, the total count of related rows in the "orders" table, and the highest price in the "orders" table, grouped against the full set of columns in the "customers" table. That query is then mapped against the Customer class. New instances of Customer will contain attributes for each column in the "customers" table as well as an "order_count" and "highest_order" attribute. Updates to the Customer object will only be reflected in the "customers" table and not the "orders" table. This is because the primary keys of the "orders" table are not represented in this mapper and therefore the table is not affected by save or delete operations. - -### Multiple Mappers for One Class {@name=multiple} - -The first mapper created for a certain class is known as that class's "primary mapper." Other mappers can be created as well, these come in two varieties. - -* **secondary mapper** - this is a mapper that must be constructed with the keyword argument `non_primary=True`, and represents a load-only mapper. Objects that are loaded with a secondary mapper will have their save operation processed by the primary mapper. It is also invalid to add new `relation()`s to a non-primary mapper. 
To use this mapper with the Session, specify it to the `query` method: - -example: - - {python} - # primary mapper - mapper(User, users_table) - - # make a secondary mapper to load User against a join - othermapper = mapper(User, users_table.join(someothertable), non_primary=True) - - # select - result = session.query(othermapper).select() - -* **entity name mapper** - this is a mapper that is a fully functioning primary mapper for a class, which is distinguished from the regular primary mapper by an `entity_name` parameter. Instances loaded with this mapper will be totally managed by this new mapper and have no connection to the original one. Most methods on `Session` include an optional `entity_name` parameter in order to specify this condition. - -example: - - {python} - # primary mapper - mapper(User, users_table) - - # make an entity name mapper that stores User objects in another table - mapper(User, alternate_users_table, entity_name='alt') - - # make two User objects - user1 = User() - user2 = User() - - # save one in in the "users" table - session.save(user1) - - # save the other in the "alternate_users_table" - session.save(user2, entity_name='alt') - - session.flush() - - # select from the alternate mapper - session.query(User, entity_name='alt').select() - -### Self Referential Mappers {@name=selfreferential} - -A self-referential mapper is a mapper that is designed to operate with an *adjacency list* table. This is a table that contains one or more foreign keys back to itself, and is usually used to create hierarchical tree structures. SQLAlchemy's default model of saving items based on table dependencies is not sufficient in this case, as an adjacency list table introduces dependencies between individual rows. Fortunately, SQLAlchemy will automatically detect a self-referential mapper and do the extra lifting to make it work. - - {python} - # define a self-referential table - trees = Table('treenodes', engine, - Column('node_id', Integer, primary_key=True), - Column('parent_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True), - Column('node_name', String(50), nullable=False), - ) - - # treenode class - class TreeNode(object): - pass - - # mapper defines "children" property, pointing back to TreeNode class, - # with the mapper unspecified. it will point back to the primary - # mapper on the TreeNode class. - TreeNode.mapper = mapper(TreeNode, trees, properties={ - 'children' : relation( - TreeNode, - cascade="all" - ), - } - ) - -This kind of mapper goes through a lot of extra effort when saving and deleting items, to determine the correct dependency graph of nodes within the tree. - -A self-referential mapper where there is more than one relationship on the table requires that all join conditions be explicitly spelled out. 
Below is a self-referring table that contains a "parent_node_id" column to reference parent/child relationships, and a "root_node_id" column which points child nodes back to the ultimate root node:
-
-    {python}
-    # define a self-referential table with several relations
-    trees = Table('treenodes', engine,
-        Column('node_id', Integer, primary_key=True),
-        Column('parent_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True),
-        Column('root_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True),
-        Column('node_name', String(50), nullable=False),
-        )
-
-    # treenode class
-    class TreeNode(object):
-        pass
-
-    # define the "children" property as well as the "root" property
-    mapper(TreeNode, trees, properties={
-        'children' : relation(
-            TreeNode,
-            primaryjoin=trees.c.parent_node_id==trees.c.node_id,
-            cascade="all",
-            backref=backref("parent", remote_side=[trees.c.node_id])
-            ),
-        'root' : relation(
-            TreeNode,
-            primaryjoin=trees.c.root_node_id==trees.c.node_id,
-            remote_side=[trees.c.node_id],
-            uselist=False
-            )
-        }
-    )
-
-The "root" property on a TreeNode is a many-to-one relationship. By default, a self-referential mapper declares relationships as one-to-many, so the extra parameter `remote_side`, pointing to a column or list of columns on the remote side of a relationship, is needed to indicate a "many-to-one" self-referring relationship (note the previous keyword argument `foreignkey` is deprecated).
-Both TreeNode examples above are available in functional form in the `examples/adjacencytree` directory of the distribution.
-
-### Statement and Result-Set ORM Queries {@name=resultset}
-
-Take any textual statement, constructed statement or result set and feed it into a Query to produce objects. Below, we define two class/mapper combinations, issue a SELECT statement, and send the result object to the `instances()` method on `Query`:
-
-    {python}
-    class User(object):
-        pass
-
-    class Address(object):
-        pass
-
-    mapper(User, users_table)
-
-    mapper(Address, addresses_table)
-
-    # select users and addresses in one query
-    # use_labels is so that the user_id columns in both tables are distinguished
-    s = select([users_table, addresses_table], users_table.c.user_id==addresses_table.c.user_id, use_labels=True)
-
-    # execute it, and process the results, asking for both User and Address objects
-    r = session.query(User, Address).instances(s.execute())
-
-    # result rows come back as tuples
-    for entry in r:
-        user = entry[0]
-        address = entry[1]
-
-Alternatively, the `from_statement()` method may be used with either a textual string or SQL construct:
-
-    {python}
-    s = select([users_table, addresses_table], users_table.c.user_id==addresses_table.c.user_id, use_labels=True)
-
-    r = session.query(User, Address).from_statement(s).all()
-
-    for entry in r:
-        user = entry[0]
-        address = entry[1]
-
-#### Combining Eager Loads with Statement/Result Set Queries
-
-When full statement/result loads are used with `Query`, SQLAlchemy does not affect the SQL query itself, and therefore has no way of tacking on its own `LEFT [OUTER] JOIN` conditions that are normally used to eager load relationships. If the query being constructed is created in such a way that it returns rows not just from a parent table (or tables) but also returns rows from child tables, the result-set mapping can be notified as to which additional properties are contained within the result set. This is done using the `contains_eager()` query option, which specifies the name of the relationship to be eagerly loaded.
-
-    {python}
-    # mapping is the users->addresses mapping
-    mapper(User, users_table, properties={
-        'addresses':relation(Address)
-    })
-
-    # define a query on USERS with an outer join to ADDRESSES
-    statement = users_table.outerjoin(addresses_table).select(use_labels=True)
-
-    # construct a Query object which expects the "addresses" results
-    query = session.query(User).options(contains_eager('addresses'))
-
-    # get results normally
-    r = query.instances(statement.execute())
-
-If the "eager" portion of the statement is "aliased", the `alias` keyword argument to `contains_eager()` may be used to indicate it. This is a string alias name or reference to an actual `Alias` object:
-
-    {python}
-    # use an alias of the addresses table
-    adalias = addresses_table.alias('adalias')
-
-    # define a query on USERS with an outer join to adalias
-    statement = users_table.outerjoin(adalias).select(use_labels=True)
-
-    # construct a Query object which expects the "addresses" results
-    query = session.query(User).options(contains_eager('addresses', alias=adalias))
-
-    # get results normally
-    {sql}r = query.from_statement(statement).all()
-    SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, adalias.address_id AS adalias_address_id,
-    adalias.user_id AS adalias_user_id, adalias.email_address AS adalias_email_address, (...other columns...)
-    FROM users LEFT OUTER JOIN email_addresses AS adalias ON users.user_id = adalias.user_id
-
-In the case that the main table itself is also aliased, the `contains_alias()` option can be used:
-
-    {python}
-    # define an aliased UNION called 'ulist'
-    statement = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist')
-
-    # add on an eager load of "addresses"
-    statement = statement.outerjoin(addresses).select(use_labels=True)
-
-    # create query, indicating "ulist" is an alias for the main table, "addresses" property should
-    # be eager loaded
-    query = create_session().query(User).options(contains_alias('ulist'), contains_eager('addresses'))
-
-    # results
-    r = query.instances(statement.execute())
-
-
-### Mapper Keyword Arguments {@name=mapperoptions}
-
-Keyword arguments which can be used with the `mapper()` function. For arguments to `relation()`, see [advdatamapping_properties_relationoptions](rel:advdatamapping_properties_relationoptions).
-
-* **allow_column_override** - if True, allows the usage of a `relation()` which has the same name as a column in the mapped table. The table column will no longer be mapped.
-* **allow_null_pks=False** - indicates that composite primary keys where one or more (but not all) columns contain NULL are valid primary keys. Primary keys which contain NULL values usually indicate that a result row does not contain an entity and should be skipped.
-* **always_refresh=False** - if True, all query operations for this mapped class will overwrite all data within object instances that already exist within the session, erasing any in-memory changes with whatever information was loaded from the database. Note that this option bypasses the usage patterns for which the Session is designed - negative side effects should be expected, and usage issues involving this flag are not supported. For a better way to refresh data, use `query.load()`, `session.refresh()`, `session.expunge()`, or `session.clear()`.
-* **batch=True** - when False, indicates that when a mapper is persisting a list of instances, each instance will be fully saved to the database before moving onto the next instance. Normally, inserts and updates are batched together per-table, such as for an inheriting mapping that spans multiple tables. This flag is for rare circumstances where custom `MapperExtension` objects are used to attach logic to `before_insert()`, `before_update()`, etc., and the user-defined logic requires that the full persistence of each instance must be completed before moving onto the next (such as logic which queries the tables for the most recent ID). Note that this flag has a significant impact on the efficiency of a large save operation. -* **column_prefix** - a string which will be prepended to the "key" name of all Columns when creating column-based properties from the given Table. Does not affect explicitly specified column-based properties. Setting `column_prefix='_'` is equivalent to defining all column-based properties as `_columnname=table.c.columnname`. See [advdatamapping_properties_colname](rel:advdatamapping_properties_colname) for information on overriding column-based attribute names. -* **concrete** - if True, indicates this mapper should use concrete table inheritance with its parent mapper. Requires `inherits` to be set. -* **entity_name** - defines this mapping as local to a particular class of entities within a single class. Allows alternate persistence mappings for a single class. See [advdatamapping_multiple](rel:advdatamapping_multiple). -* **extension** - a MapperExtension instance or list of MapperExtension instances which will be applied to all operations by this Mapper. See [advdatamapping_extending](rel:advdatamapping_extending). -* **inherits** - another Mapper or class for which this Mapper will have an inheritance relationship with. See the examples in [advdatamapping_inheritance](rel:advdatamapping_inheritance). -* **inherit_condition** - for joined table inheritance, a SQL expression (constructed ClauseElement) which will define how the two tables are joined; -defaults to a natural join between the two tables. -* **non_primary=False** - if True, construct a Mapper that will define only the selection of instances, not their persistence. It essentially creates a mapper that can be used for querying but does not define how instances of the class are stored. A non_primary mapper is always created after a regular primary mapper has already been created for the class. To use one, send it in place of the class argument when creating a query, such as `session.query(somemapper)`. Note that it is usually invalid to define additional relationships on a non_primary mapper as they will conflict with those of the primary. See [advdatamapping_multiple](rel:advdatamapping_multiple). -* **order_by** - a single Column or list of Columns for which selection operations should use as the default ordering for entities. Defaults to the OID/ROWID of the table if any, or the first primary key column of the table. See [advdatamapping_orderby](rel:advdatamapping_orderby). -* **polymorphic_on** - used with mappers in an inheritance relationship, a Column which will identify the class/mapper combination to be used with a particular row. requires the `polymorphic_identity` value to be set for all mappers in the inheritance hierarchy. -* **polymorphic_identity** - a value which will be stored in the Column denoted by `polymorphic_on`, corresponding to the "class identity" of this mapper. 
See [advdatamapping_inheritance](rel:advdatamapping_inheritance). -* **primary_key** - a list of Column objects which define the "primary key" to be used against this mapper's selectable unit. The mapper normally determines these automatically from the given `local_table` of the mapper combined against any inherited tables. When this argument is specified, the primary keys of the mapped table if any are disregarded in place of the columns given. This can be used to provide primary key identity to a table that has no PKs defined at the schema level, or to modify what defines "identity" for a particular table. -* **properties** - a dictionary mapping the string names of object attributes to MapperProperty instances, which define the persistence behavior of that attribute. Note that the columns in the mapped table are automatically converted into ColumnProperty instances based on the "key" property of each Column (although they can be overridden using this dictionary). -* **select_table** - used with polymorphic mappers, this is a `Selectable` which will take the place of the `Mapper`'s main table argument when -performing queries. -* **version_id_col** - a Column which must have an integer type that will be used to keep a running "version id" of mapped entities in the database. This is used during save operations to ensure that no other thread or process has updated the instance during the lifetime of the entity, else a ConcurrentModificationError exception is thrown. - -### Extending Mapper {@name=extending} - -Mappers can have functionality augmented or replaced at many points in its execution via the usage of the MapperExtension class. This class is just a series of "hooks" where various functionality takes place. An application can make its own MapperExtension objects, overriding only the methods it needs. Methods that are not overridden return the special value `sqlalchemy.orm.mapper.EXT_PASS`, which indicates the operation should proceed as normally. - - {python} - class MapperExtension(object): - """base implementation for an object that provides overriding behavior to various - Mapper functions. For each method in MapperExtension, a result of EXT_PASS indicates - the functionality is not overridden.""" - def get_session(self): - """called to retrieve a contextual Session instance with which to - register a new object. Note: this is not called if a session is - provided with the __init__ params (i.e. _sa_session)""" - return EXT_PASS - def select_by(self, query, *args, **kwargs): - """overrides the select_by method of the Query object""" - return EXT_PASS - def select(self, query, *args, **kwargs): - """overrides the select method of the Query object""" - return EXT_PASS - def create_instance(self, mapper, selectcontext, row, class_): - """called when a new object instance is about to be created from a row. - the method can choose to create the instance itself, or it can return - None to indicate normal object creation should take place. - - mapper - the mapper doing the operation - - selectcontext - SelectionContext corresponding to the instances() call - - row - the result row from the database - - class_ - the class we are mapping. - """ - return EXT_PASS - def append_result(self, mapper, selectcontext, row, instance, identitykey, result, isnew): - """called when an object instance is being appended to a result list. 
- - If this method returns EXT_PASS, it is assumed that the mapper should do the appending, else - if this method returns any other value or None, it is assumed that the append was handled by this method. - - mapper - the mapper doing the operation - - selectcontext - SelectionContext corresponding to the instances() call - - row - the result row from the database - - instance - the object instance to be appended to the result - - identitykey - the identity key of the instance - - result - list to which results are being appended - - isnew - indicates if this is the first time we have seen this object instance in the current result - set. if you are selecting from a join, such as an eager load, you might see the same object instance - many times in the same result set. - """ - return EXT_PASS - def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew): - """called right before the mapper, after creating an instance from a row, passes the row - to its MapperProperty objects which are responsible for populating the object's attributes. - If this method returns EXT_PASS, it is assumed that the mapper should do the appending, else - if this method returns any other value or None, it is assumed that the append was handled by this method. - - Essentially, this method is used to have a different mapper populate the object: - - def populate_instance(self, mapper, selectcontext, instance, row, identitykey, isnew): - othermapper.populate_instance(selectcontext, instance, row, identitykey, isnew, frommapper=mapper) - return True - """ - return EXT_PASS - def before_insert(self, mapper, connection, instance): - """called before an object instance is INSERTed into its table. - - this is a good place to set up primary key values and such that arent handled otherwise.""" - return EXT_PASS - def before_update(self, mapper, connection, instance): - """called before an object instance is UPDATED""" - return EXT_PASS - def after_update(self, mapper, connection, instance): - """called after an object instance is UPDATED""" - return EXT_PASS - def after_insert(self, mapper, connection, instance): - """called after an object instance has been INSERTed""" - return EXT_PASS - def before_delete(self, mapper, connection, instance): - """called before an object instance is DELETEed""" - return EXT_PASS - def after_delete(self, mapper, connection, instance): - """called after an object instance is DELETEed""" - return EXT_PASS -To use MapperExtension, make your own subclass of it and just send it off to a mapper: - - {python} - m = mapper(User, users_table, extension=MyExtension()) - -Multiple extensions will be chained together and processed in order; they are specified as a list: - - {python} - m = mapper(User, users_table, extension=[ext1, ext2, ext3]) - diff --git a/doc/build/content/copyright.txt b/doc/build/content/copyright.txt index bc76e9f640..a9e2cb228f 100644 --- a/doc/build/content/copyright.txt +++ b/doc/build/content/copyright.txt @@ -3,7 +3,7 @@ Appendix: Copyright {@name=copyright} This is the MIT license: http://www.opensource.org/licenses/mit-license.php -Copyright (c) 2005, 2006, 2007 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael +Copyright (c) 2005, 2006, 2007, 2008 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael Bayer. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this diff --git a/doc/build/content/datamapping.txt b/doc/build/content/datamapping.txt deleted file mode 100644 index d35fecfb34..0000000000 --- a/doc/build/content/datamapping.txt +++ /dev/null @@ -1,1044 +0,0 @@ -[alpha_api]: javascript:alphaApi() -[alpha_implementation]: javascript:alphaImplementation() - -Data Mapping {@name=datamapping} -============ - -### Basic Data Mapping {@name=datamapping} - -Data mapping describes the process of defining `Mapper` objects, which associate table metadata with user-defined classes. - -When a `Mapper` is created to associate a `Table` object with a class, all of the columns defined in the `Table` object are associated with the class via property accessors, which add overriding functionality to the normal process of setting and getting object attributes. These property accessors keep track of changes to object attributes; these changes will be stored to the database when the application "flushes" the current state of objects. This pattern is called a *Unit of Work* pattern. - -### Synopsis {@name=synopsis} - -Starting with a `Table` definition and a minimal class construct, the two are associated with each other via the `mapper()` function [[api](rel:docstrings_sqlalchemy.orm.mapper_Mapper)], which generates an object called a `Mapper`. SA associates the class and all instances of that class with this particular `Mapper`, which is then stored in a global registry. - - {python} - from sqlalchemy import * - - # metadata - meta = MetaData() - - # table object - users_table = Table('users', meta, - Column('user_id', Integer, primary_key=True), - Column('user_name', String(16)), - Column('fullname', String(100)), - Column('password', String(20)) - ) - - # class definition - class User(object): - pass - - # create a mapper and associate it with the User class. - mapper(User, users_table) - -That's all for configuration. Next, we will create an `Engine` and bind it to a `Session`, which represents a local collection of mapped objects to be operated upon. - - {python} - # engine - engine = create_engine("sqlite:///mydb.db") - - # session - session = create_session(bind=engine) - -The `session` represents a "workspace" which can load objects and persist changes to the database. Note also that the `bind` parameter is optional; if the underlying `Table` objects are bound as described in [metadata_tables_binding](rel:metadata_tables_binding), it's not needed. A `Session` [[doc](rel:unitofwork)] [[api](rel:docstrings_sqlalchemy.orm.session_Session)] is best created as local to a particular set of related data operations, such as scoped within a function call, or within a single application request cycle. Next we illustrate a rudimentary query which will load a single object instance. We will modify one of its attributes and persist the change back to the database.
- - {python} - # select - {sql}user = session.query(User).filter_by(user_name='fred').first() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users - WHERE users.user_name = :users_user_name ORDER BY users.oid LIMIT 1 - {'users_user_name': 'fred'} - - # modify - user.user_name = 'fred jones' - - # flush - saves everything that changed - # within the scope of our Session - {sql}session.flush() - BEGIN - UPDATE users SET user_name=:user_name - WHERE users.user_id = :user_id - [{'user_name': 'fred jones', 'user_id': 1}] - COMMIT - -Things to note from the above include that the loaded `User` object has an attribute named `user_name` on it, which corresponds to the `user_name` column in `users_table`; this attribute was configured at the class level by the `Mapper`, as part of it's post-initialization process (this process occurs normally when the mapper is first used). Our modify operation on this attribute caused the object to be marked as "dirty", which was picked up automatically within the subsequent `flush()` process. The `flush()` is the point at which all changes to objects within the `Session` are persisted to the database, and the `User` object is no longer marked as "dirty" until it is again modified. - -### The Query Object {@name=query} - -The method `session.query(class_or_mapper)` returns a `Query` object [[api](rel:docstrings_sqlalchemy.orm.query_Query)]. `Query` implements methods which are used to produce and execute select statements tailored for loading object instances. It returns object instances in all cases; usually as a list, but in some cases scalar objects, or lists of tuples which contain multiple kinds of objects and sometimes individual scalar values. - -A `Query` is created from the `Session`, relative to a particular class we wish to load. - - {python} - # get a query from a Session based on class: - query = session.query(User) - -Alternatively, an actual `Mapper` instance can be specified instead of a class: - - {python} - # locate the mapper corresponding to the User class - usermapper = class_mapper(User) - - # create query against the User mapper - query = session.query(usermapper) - -A query which joins across multiple tables may also be used to request multiple entities, such as: - - {python} - query = session.query(User, Address) - -Once we have a query, we can start loading objects. The methods `filter()` and `filter_by()` handle narrowing results, and the methods `all()`, `one()`, and `first()` exist to return all, exactly one, or the first result of the total set of results. Note that all methods are *generative*, meaning that on each call that doesn't return results, you get a **new** `Query` instance. - -The `filter_by()` method works with keyword arguments, which are combined together via AND: - - {python} - {sql}result = session.query(User).filter_by(name='john', fullname='John Smith').all() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users - WHERE users.user_name = :users_user_name AND users.fullname = :users_fullname - ORDER BY users.oid - {'users_user_name': 'john', 'users_fullname': 'John Smith'} - -Whereas `filter()` works with constructed SQL expressions, i.e. 
those described in [sql](rel:sql): - - {python} - {sql}result = session.query(User).filter(users_table.c.name=='john').all() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users - WHERE users.user_name = :users_user_name - ORDER BY users.oid - {'users_user_name': 'john'} - -Sometimes, constructing SQL via expressions can be cumbersome. For quick SQL expression, the `filter()` method can also accomodate straight text: - - {python} - {sql}result = session.query(User).filter("user_id>224").all() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users - WHERE users.user_id>224 - ORDER BY users.oid - {} - -When using text, bind parameters can be specified the same way as in a `text()` clause, using a colon. To specify the bind parameter values, use the `params()` method: - - {python} - {sql}result = session.query(User).filter("user_id>:value").params(value=224).all() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users - WHERE users.user_id>:value - ORDER BY users.oid - {'value': 224} - -Multiple `filter()` and `filter_by()` expressions may be combined together. The resulting statement groups them using AND. - - {python} - result = session.query(User).filter(users_table.c.user_id>224).filter_by(name='john'). - {sql} filter(users.c.fullname=='John Smith').all() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users - WHERE users.user_id>:users_user_id AND users.user_name = :users_user_name - AND users.fullname = :users_fullname - ORDER BY users.oid - {'users_user_name': 'john', 'users_fullname': 'John Smith', 'users_user_id': 224} - -`filter_by()`'s keyword arguments can also take mapped object instances as comparison arguments. We'll illustrate this later when we talk about object relationships. 
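- -Because these `Query` methods are generative, the text-based filters, bind parameters and keyword criteria shown above can all be composed in a single chain. The following is a minimal sketch rather than documented output (it reuses the mapping above, and assumes the `user_name` attribute generated from that mapping): - - {python} - # combine a text-based filter (bind parameter supplied via params()) with a - # keyword-based filter; as before, the criteria are joined together using AND - result = session.query(User).filter("user_id>:value").params(value=224).filter_by(user_name='john').all()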
- -Note that all conjunctions are available explicitly, such as `and_()` and `or_()`, when using `filter()`: - - {python} - result = session.query(User).filter( - and_(users_table.c.user_id>224, or_(users_table.c.name=='john', users_table.c.name=='ed')) - ).all() - -It's also straightforward to use an entirely string-based statement, using `from_statement()`; just ensure that the columns clause of the statement contains the column names normally used by the mapper (here illustrated using an asterisk): - - {python} - {sql}result = session.query(User).from_statement("SELECT * FROM users").all() - SELECT * FROM users - {} - -`from_statement()` can also accommodate `select()` constructs: - - {python} - result = session.query(User).from_statement( - select([users], users.c.name<'e', having=users.c.name==func.max(users.c.name), group_by=[c for c in users.c]) - {sql} ).all() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users - WHERE users.user_name<:users_user_name HAVING users.user_name = max(users.user_name) - GROUP BY users.user_id, users.user_name, users.fullname, users.password - ORDER BY users.oid - {'users_user_name': 'e'} - -Any set of filter criteria (or no criteria) can be distilled into a count of rows using `count()`: - - {python} - {sql}num = session.query(User).filter(users_table.c.user_id>224).count() - SELECT count(users.id) FROM users WHERE users.user_id>:users_user_id - {'users_user_id': 224} - -Rows are limited and offset using `limit()` and `offset()`: - - {python} - {sql}result = session.query(User).limit(20).offset(5).all() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users ORDER BY users.oid - LIMIT 20 OFFSET 5 - {} - -And ordering is applied, using `Column` objects and related SQL constructs, with `order_by()`: - - {python} - {sql}result = session.query(User).order_by(desc(users_table.c.user_name)).all() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users ORDER BY users.user_name DESC - {} - -The `first()` and `one()` methods will also limit rows, and both will return a single object, instead of a list. In the case of `first()`, rows are limited to just one, and the result is returned as a scalar. In the case of `one()`, rows are limited to *two*; however, only one is returned. If two rows are matched, an exception is raised.
- - {python} - # load the first result - user = session.query(User).first() - - # load exactly *one* result - if more than one result matches, an exception is raised - user = session.query(User).filter_by(name='jack').one() - -The `Query`, when evaluated as an iterator, executes results immediately, using whatever state has been built up: - - {python} - {sql}result = list(session.query(User)) - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users ORDER BY users.oid - {} - -Array indexes and slices work too, adding the corresponding LIMIT and OFFSET clauses: - - {python} - {sql}result = list(session.query(User)[1:3]) - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users ORDER BY users.oid - LIMIT 2 OFFSET 1 - {} - -A scalar index returns a scalar result immediately: - - {python} - {sql}user = session.query(User)[2] - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.fullname AS users_fullname, users.password AS users_password - FROM users ORDER BY users.oid - LIMIT 1 OFFSET 2 - {} - -Theres also a way to combine scalar results with objects, using `add_column()`. This is often used for functions and aggregates. When `add_column()` (or its cousin `add_entity()`, described later) is used, tuples are returned: - - {python} - result = session.query(User).add_column(func.max(users_table.c.name)).group_by([c for c in users_table.c]).all() - for r in result: - print "user:", r[0] - print "max name:", r[1] - -Later in this chapter, we'll discuss how to configure relations between mapped classes. Once that's done, we'll discuss how to use table joins in [datamapping_joins](rel:datamapping_joins). - -#### Loading by Primary Key {@name=primarykey} - -The `get()` method loads a single instance, given the primary key value of the desired entity: - - {python} - # load user with primary key 15 - user = query.get(15) - -The `get()` method, because it has the actual primary key value of the instance, can return an already-loaded instance from the `Session` without performing any SQL. It is the only result-returning method on `Query` that does not issue SQL to the database in all cases. - -To issue a composite primary key to `get()`, use a tuple. The order of the arguments matches that of the primary key columns of the table: - - {python} - myobj = query.get((27, 3, 'receipts')) - -Another special method on `Query` is `load()`. This method has the same signature as `get()`, except it always **refreshes** the returned instance with the latest data from the database. This is in fact a unique behavior, since as we will see in the [unitofwork](rel:unitofwork) chapter, most `Query` methods do not reload the contents of instances which are already present in the session. - -#### Column Objects Available via their Mapped Class {@name=columnsonclass} - -Some of the above examples above illustrate the usage of the mapper's Table object to provide the columns for a WHERE Clause. These columns are also accessible off of the mapped class directly. 
When a mapper is assigned to a class, it also attaches a special property accessor `c` to the class itself, which can be used just like that of a `Table` object to access the columns of the table: - - {python} - userlist = session.query(User).filter(User.c.user_id==12).first() - -In version 0.4 of SQLAlchemy, the "c" prefix will no longer be needed. - -### Saving Objects {@name=saving} - -When objects corresponding to mapped classes are created or manipulated, all changes are logged by the `Session` object. The changes are then written to the database when an application calls `flush()`. This pattern is known as a *Unit of Work*, and has many advantages over saving individual objects or attributes on those objects with individual method invocations. Domain models can be built with far greater complexity with no concern over the order of saves and deletes, excessive database round-trips and write operations, or deadlocking issues. The `flush()` operation batches its SQL statements into a transaction, and can also perform optimistic concurrency checks (using a version id column) to ensure the proper number of rows were in fact affected. - -The Unit of Work is a powerful tool, and has some important concepts that should be understood in order to use it effectively. See the [unitofwork](rel:unitofwork) section for a full description on all its operations. - -When a mapper is created, the target class has its mapped properties decorated by specialized property accessors that track changes. New objects by default must be explicitly added to the `Session` using the `save()` method: - - {python} - mapper(User, users_table) - - # create a new User - myuser = User() - myuser.user_name = 'jane' - myuser.password = 'hello123' - - # create another new User - myuser2 = User() - myuser2.user_name = 'ed' - myuser2.password = 'lalalala' - - # create a Session and save them - sess = create_session() - sess.save(myuser) - sess.save(myuser2) - - # load a third User from the database - {sql}myuser3 = sess.query(User).filter_by(name='fred').all()[0] - SELECT users.user_id AS users_user_id, - users.user_name AS users_user_name, users.password AS users_password - FROM users WHERE users.user_name = :users_user_name - {'users_user_name': 'fred'} - - myuser3.user_name = 'fredjones' - - # save all changes - {sql}session.flush() - UPDATE users SET user_name=:user_name - WHERE users.user_id =:users_user_id - [{'users_user_id': 1, 'user_name': 'fredjones'}] - INSERT INTO users (user_name, password) VALUES (:user_name, :password) - {'password': 'hello123', 'user_name': 'jane'} - INSERT INTO users (user_name, password) VALUES (:user_name, :password) - {'password': 'lalalala', 'user_name': 'ed'} - -The mapped class can also specify whatever methods and/or constructor it wants: - - {python} - class User(object): - def __init__(self, user_name, password): - self.user_id = None - self.user_name = user_name - self.password = password - def get_name(self): - return self.user_name - def __repr__(self): - return "User id %s name %s password %s" % (repr(self.user_id), - repr(self.user_name), repr(self.password)) - mapper(User, users_table) - - sess = create_session() - u = User('john', 'foo') - sess.save(u) - {sql}session.flush() - INSERT INTO users (user_name, password) VALUES (:user_name, :password) - {'password': 'foo', 'user_name': 'john'} - - >>> u - User id 1 name 'john' password 'foo' - -Note that the **__init__() method is not called when the instance is loaded**. 
This is so that classes can define operations that are specific to their initial construction which are not re-called when the object is restored from the database, and is similar in concept to how Python's `pickle` module calls `__new__()` when deserializing instances. To allow `__init__()` to be called at object load time, or to define any other sort of on-load operation, create a `MapperExtension` which supplies the `create_instance()` method (see [advdatamapping_extending](rel:advdatamapping_extending), as well as the example in the FAQ). - -### Defining and Using Relationships {@name=relations} - -So that covers how to map the columns in a table to an object, how to load objects, create new ones, and save changes. The next step is how to define an object's relationships to other database-persisted objects. This is done via the `relation` function [[doc](rel:advdatamapping_properties_relationoptions)][[api](rel:docstrings_sqlalchemy.orm_modfunc_relation)] provided by the `orm` module. - -#### One to Many {@name=onetomany} - -With our User class, lets also define the User has having one or more mailing addresses. First, the table metadata: - - {python} - from sqlalchemy import * - - metadata = MetaData() - - # define user table - users_table = Table('users', metadata, - Column('user_id', Integer, primary_key=True), - Column('user_name', String(16)), - Column('password', String(20)) - ) - - # define user address table - addresses_table = Table('addresses', metadata, - Column('address_id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey("users.user_id")), - Column('street', String(100)), - Column('city', String(80)), - Column('state', String(2)), - Column('zip', String(10)) - ) - -Of importance here is the addresses table's definition of a *foreign key* relationship to the users table, relating the user_id column into a parent-child relationship. When a `Mapper` wants to indicate a relation of one object to another, the `ForeignKey` relationships are the default method by which the relationship is determined (options also exist to describe the relationships explicitly). - -So then lets define two classes, the familiar `User` class, as well as an `Address` class: - - {python} - class User(object): - def __init__(self, user_name, password): - self.user_name = user_name - self.password = password - - class Address(object): - def __init__(self, street, city, state, zip): - self.street = street - self.city = city - self.state = state - self.zip = zip - -And then a `Mapper` that will define a relationship of the `User` and the `Address` classes to each other as well as their table metadata. 
We will add an additional mapper keyword argument `properties` which is a dictionary relating the names of class attributes to database relationships, in this case a `relation` object against a newly defined mapper for the Address class: - - {python} - mapper(Address, addresses_table) - mapper(User, users_table, properties = { - 'addresses' : relation(Address) - } - ) - -Lets do some operations with these classes and see what happens: - - {python} - engine = create_engine('sqlite:///mydb.db') - - # create tables - metadata.create_all(engine) - - session = create_session(bind=engine) - - u = User('jane', 'hihilala') - u.addresses.append(Address('123 anywhere street', 'big city', 'UT', '76543')) - u.addresses.append(Address('1 Park Place', 'some other city', 'OK', '83923')) - - session.save(u) - session.flush() - {opensql}INSERT INTO users (user_name, password) VALUES (:user_name, :password) - {'password': 'hihilala', 'user_name': 'jane'} - INSERT INTO addresses (user_id, street, city, state, zip) VALUES (:user_id, :street, :city, :state, :zip) - {'city': 'big city', 'state': 'UT', 'street': '123 anywhere street', 'user_id':1, 'zip': '76543'} - INSERT INTO addresses (user_id, street, city, state, zip) VALUES (:user_id, :street, :city, :state, :zip) - {'city': 'some other city', 'state': 'OK', 'street': '1 Park Place', 'user_id':1, 'zip': '83923'} - -A lot just happened there! The `Mapper` figured out how to relate rows in the addresses table to the users table, and also upon flush had to determine the proper order in which to insert rows. After the insert, all the `User` and `Address` objects have their new primary and foreign key attributes populated. - -Also notice that when we created a `Mapper` on the `User` class which defined an `addresses` relation, the newly created `User` instance magically had an "addresses" attribute which behaved like a list. This list is in reality a Python `property` which will return an instance of `sqlalchemy.orm.attributes.InstrumentedList`. This is a generic collection-bearing object which can represent lists, sets, dictionaries, or any user-defined collection class. By default it represents a list: - - {python} - del u.addresses[1] - u.addresses.append(Address('27 New Place', 'Houston', 'TX', '34839')) - - session.flush() - - {opensql}UPDATE addresses SET user_id=:user_id - WHERE addresses.address_id = :addresses_address_id - [{'user_id': None, 'addresses_address_id': 2}] - INSERT INTO addresses (user_id, street, city, state, zip) - VALUES (:user_id, :street, :city, :state, :zip) - {'city': 'Houston', 'state': 'TX', 'street': '27 New Place', 'user_id': 1, 'zip': '34839'} - -Note that when creating a relation with the `relation()` function, the target can either be a class, in which case the primary mapper for that class is used as the target, or a `Mapper` instance itself, as returned by the `mapper()` function. - -#### Lifecycle Relations {@name=lifecycle} - -In the previous example, a single address was removed from the `addresses` attribute of a `User` object, resulting in the corresponding database row being updated to have a user_id of `None`. But now, theres a mailing address with no user_id floating around in the database of no use to anyone. How can we avoid this ? 
This is achieved by using the `cascade` parameter of `relation`: - - {python} - clear_mappers() # clear mappers from the previous example - mapper(Address, addresses_table) - mapper(User, users_table, properties = { - 'addresses' : relation(Address, cascade="all, delete-orphan") - } - ) - - del u.addresses[1] - u.addresses.append(Address('27 New Place', 'Houston', 'TX', '34839')) - - session.flush() - {opensql}INSERT INTO addresses (user_id, street, city, state, zip) - VALUES (:user_id, :street, :city, :state, :zip) - {'city': 'Houston', 'state': 'TX', 'street': '27 New Place', 'user_id': 1, 'zip': '34839'} - DELETE FROM addresses WHERE addresses.address_id = :address_id - [{'address_id': 2}] - -In this case, with the `delete-orphan` **cascade rule** set, the element that was removed from the addresses list was also removed from the database. Specifying `cascade="all, delete-orphan"` means that every persistence operation performed on the parent object will be *cascaded* to the child object or objects handled by the relation, and additionally that each child object cannot exist without being attached to a parent. Such a relationship indicates that the **lifecycle** of the `Address` objects is bounded by that of their parent `User` object. - -Cascading is described fully in [unitofwork_cascade](rel:unitofwork_cascade). - -#### Backreferences {@name=backreferences} - -By creating relations with the `backref` keyword, a bi-directional relationship can be created which will keep both ends of the relationship updated automatically, independently of database operations. Below, the `User` mapper is created with an `addresses` property, and the corresponding `Address` mapper receives a "backreference" to the `User` object via the property name `user`: - - {python} - mapper(Address, addresses_table) - mapper(User, users_table, properties = { - 'addresses' : relation(Address, backref='user') - } - ) - - u = User('fred', 'hi') - a1 = Address('123 anywhere street', 'big city', 'UT', '76543') - a2 = Address('1 Park Place', 'some other city', 'OK', '83923') - - # append a1 to u - u.addresses.append(a1) - - # attach u to a2 - a2.user = u - - # the bi-directional relation is maintained - >>> u.addresses == [a1, a2] - True - >>> a1.user is u and a2.user is u - True - -The backreference feature also works with many-to-many relationships, which are described later. When creating a backreference, a corresponding property (i.e. a second `relation()`) is placed on the child mapper. The default arguments to this property can be overridden using the `backref()` function: - - {python} - mapper(User, users_table) - mapper(Address, addresses_table, properties={ - 'user':relation(User, backref=backref('addresses', cascade="all, delete-orphan")) - }) - - -The `backref()` function is often used to set up a bi-directional one-to-one relationship. This is because the `relation()` function by default creates a "one-to-many" relationship when presented with a primary key/foreign key relationship, but the `backref()` function can redefine the `uselist` property to make it a scalar: - - {python} - mapper(User, users_table) - mapper(Address, addresses_table, properties={ - 'user' : relation(User, backref=backref('address', uselist=False)) - }) - -### Querying with Joins {@name=joins} - -When using mappers that have relationships to other mappers, the need to specify query criteria across multiple tables arises. SQLAlchemy provides several core techniques which offer this functionality.
- -One way is just to build up the join criterion yourself. This is easy to do using `filter()`: - - {python} - {sql}l = session.query(User).filter(users.c.user_id==addresses.c.user_id). - filter(addresses.c.street=='123 Green Street').all() - SELECT users.user_id AS users_user_id, - users.user_name AS users_user_name, users.password AS users_password - FROM users, addresses - WHERE users.user_id=addresses.user_id - AND addresses.street=:addresses_street - ORDER BY users.oid - {'addresses_street', '123 Green Street'} - -Above, we specified selection criterion that included columns from both the `users` and the `addresses` table. Note that in this case, we had to specify not just the matching condition to the `street` column on `addresses`, but also the join condition between the `users` and `addresses` table. If we didn't do that, we'd get a *cartesian product* of both tables. The `Query` object never "guesses" what kind of join you'd like to use, but makes it easy using the `join()` method which we'll get to in a moment. - -A way to specify joins very explicitly, using the SQL `join()` construct, is possible via the `select_from()` method on `Query`: - - {python} - {sql}l = session.query(User).select_from(users_table.join(addresses_table)). - filter(addresses_table.c.street=='123 Green Street').all() - SELECT users.user_id AS users_user_id, - users.user_name AS users_user_name, users.password AS users_password - FROM users JOIN addresses ON users.user_id=addresses.user_id - WHERE addresses.street=:addresses_street - ORDER BY users.oid - {'addresses_street', '123 Green Street'} - -But the easiest way to join is automatically, using the `join()` method on `Query`. Just give this method the path from A to B, using the name of a mapped relationship directly: - - {python} - {sql}l = session.query(User).join('addresses'). - filter(addresses_table.c.street=='123 Green Street').all() - SELECT users.user_id AS users_user_id, - users.user_name AS users_user_name, users.password AS users_password - FROM users JOIN addresses ON users.user_id=addresses.user_id - WHERE addresses.street=:addresses_street - ORDER BY users.oid - {'addresses_street', '123 Green Street'} - -Each time the `join()` is called on `Query`, the **joinpoint** of the query is moved to be that of the endpoint of the join. As above, when we joined from `users_table` to `addresses_table`, all subsequent criterion used by `filter_by()` are against the `addresses` table. If we wanted to filter back on the starting table again, we can use the `reset_joinpoint()` function: - - {python} - l = session.query(User).join('addresses'). - filter_by(street='123 Green Street'). - reset_joinpoint().filter_by(user_name='ed').all() - -With `reset_joinpoint()`, we can also issue new `join()`s which will start back from the root table. - -In all cases, we can get the `User` and the matching `Address` objects back at the same time, by telling the session we want both. This returns the results as a list of tuples: - - {python} - result = session.query(User, Address).join('addresses'). - filter(addresses_table.c.street=='123 Green Street').all() - for r in result: - print "User:", r[0] - print "Address:", r[1] - -The above syntax is shorthand for using the `add_entity()` method: - - {python} - session.query(User).add_entity(Address).join('addresses').all() - -To join across multiple relationships, specify them in a list. 
Below, we load a `ShoppingCart`, limiting its `cartitems` collection to the single item which has a `price` object whose `amount` column is 47.95: - - {python} - cart = session.query(ShoppingCart).join(['cartitems', 'price']).filter_by(amount=47.95).one() - -`filter_by()` can also generate joins in some cases, such as when comparing to an object instance: - - {python} - # get an instance of Address. assume its primary key identity - # is 12. - someaddress = session.query(Address).filter_by(street='123 Green Street').one() - - # look for User instances which have the - # "someaddress" instance in their "addresses" collection - {sql}l = session.query(User).filter_by(addresses=someaddress).all() - SELECT users.user_id AS users_user_id, - users.user_name AS users_user_name, users.password AS users_password - FROM users, addresses - WHERE users.user_id=addresses.user_id - AND addresses.address_id=:addresses_address_id - ORDER BY users.oid - {'addresses_address_id': 12} - -You can also create joins in "reverse", that is, to find an object with a certain parent. This is accomplished using `with_parent()`: - - {python} - # load a user - someuser = session.query(User).get(2) - - # load an address with that user as a parent and email address foo@bar.com - {sql}someaddresses = session.query(Address).with_parent(someuser). - filter_by(email_address="foo@bar.com").all() - SELECT addresses.address_id AS addresses_address_id, - addresses.user_id AS addresses_user_id, addresses.street AS addresses_street, - addresses.city AS addresses_city, addresses.state AS addresses_state, - addresses.zip AS addresses_zip FROM addresses - WHERE addresses.email_address = :addresses_email_address AND - addresses.user_id = :users_user_id ORDER BY addresses.oid - {'users_user_id': 1, 'addresses_email_address': 'foo@bar.com'} - -Sometimes it's necessary to create repeated joins that are independent of each other, even though they reference the same tables. Using our one-to-many setup, an example is to locate users who have two particular email addresses. We can do this using table aliases: - - {python} - ad1 = addresses_table.alias('ad1') - ad2 = addresses_table.alias('ad2') - {sql}result = session.query(User).filter(and_( - ad1.c.user_id==users.c.user_id, - ad1.c.email_address=='foo@bar.com', - ad2.c.user_id==users.c.user_id, - ad2.c.email_address=='lala@yahoo.com' - )).all() - SELECT users.user_id AS users_user_id, - users.user_name AS users_user_name, users.password AS users_password - FROM users, addresses AS ad1, addresses AS ad2 - WHERE users.user_id=ad1.user_id - AND ad1.email_address=:ad1_email_address - AND users.user_id=ad2.user_id - AND ad2.email_address=:ad2_email_address - ORDER BY users.oid - {'ad1_email_address': 'foo@bar.com', 'ad2_email_address': 'lala@yahoo.com'} - -Version 0.4 of SQLAlchemy will include better support for issuing queries like the above with less verbosity. - -### Loading Relationships {@name=selectrelations} - -We've seen how the `relation` specifier affects the saving of an object and its child items, and also how it allows us to build joins. How do we get the actual related items loaded? By default, the `relation()` function indicates that the related property should have a *lazy loader* attached when instances of the parent object are loaded from the database; this is just a callable function that when accessed will invoke a second SQL query to load the child objects of the parent.
- - {python} - # define a user mapper - mapper(User, users_table, properties = { - 'addresses' : relation(Address) - }) - - # define an address mapper - mapper(Address, addresses_table) - - # select users where username is 'jane', get the first element of the list - # this will incur a load operation for the parent table - {sql}user = session.query(User).filter(User.c.user_name=='jane')[0] - SELECT users.user_id AS users_user_id, - users.user_name AS users_user_name, users.password AS users_password - FROM users WHERE users.user_name = :users_user_name ORDER BY users.oid - {'users_user_name': 'jane'} - - # iterate through the User object's addresses. this will incur an - # immediate load of those child items - {sql}for a in user.addresses: - SELECT addresses.address_id AS addresses_address_id, - addresses.user_id AS addresses_user_id, addresses.street AS addresses_street, - addresses.city AS addresses_city, addresses.state AS addresses_state, - addresses.zip AS addresses_zip FROM addresses - WHERE addresses.user_id = :users_user_id ORDER BY addresses.oid - {'users_user_id': 1} - - print repr(a) - -#### Eager Loading {@name=eagerload} - -Eager Loading is another way for relationships to be loaded. It describes the loading of parent and child objects across a relation using a single query. The purpose of eager loading is strictly one of performance enhancement; eager loading has **no impact** on the results of a query, except that when traversing child objects within the results, lazy loaders will not need to issue separate queries to load those child objects. - -With just a single parameter `lazy=False` specified to the relation object, the parent and child SQL queries can be joined together. - - {python} - mapper(Address, addresses_table) - mapper(User, users_table, properties = { - 'addresses' : relation(Address, lazy=False) - } - ) - - {sql}users = session.query(User).filter(User.c.user_name=='Jane').all() - SELECT users.user_name AS users_user_name, users.password AS users_password, - users.user_id AS users_user_id, addresses_4fb8.city AS addresses_4fb8_city, - addresses_4fb8.address_id AS addresses_4fb8_address_id, addresses_4fb8.user_id AS addresses_4fb8_user_id, - addresses_4fb8.zip AS addresses_4fb8_zip, addresses_4fb8.state AS addresses_4fb8_state, - addresses_4fb8.street AS addresses_4fb8_street - FROM users LEFT OUTER JOIN addresses AS addresses_4fb8 ON users.user_id = addresses_4fb8.user_id - WHERE users.user_name = :users_user_name ORDER BY users.oid, addresses_4fb8.oid - {'users_user_name': 'jane'} - - for u in users: - print repr(u) - for a in u.addresses: - print repr(a) - -Above, a pretty ambitious query is generated just by specifying that the User should be loaded with its child Addresses in one query. When the mapper processes the results, it uses an *Identity Map* to keep track of objects that were already loaded, based on their primary key identity. Through this method, the redundant rows produced by the join are organized into the distinct object instances they represent. - -Recall that eager loading has no impact on the results of the query. What if our query included our own join criterion? The eager loading query accomodates this using aliases, and is immune to the effects of additional joins being specified in the original query. 
Joining against the "addresses" table to locate users with a certain street results in this behavior: - - {python} - {sql}users = session.query(User).join('addresses').filter_by(street='123 Green Street').all() - - SELECT users.user_name AS users_user_name, - users.password AS users_password, users.user_id AS users_user_id, - addresses_6ca7.city AS addresses_6ca7_city, - addresses_6ca7.address_id AS addresses_6ca7_address_id, - addresses_6ca7.user_id AS addresses_6ca7_user_id, - addresses_6ca7.zip AS addresses_6ca7_zip, addresses_6ca7.state AS addresses_6ca7_state, - addresses_6ca7.street AS addresses_6ca7_street - FROM users JOIN addresses on users.user_id = addresses.user_id - LEFT OUTER JOIN addresses AS addresses_6ca7 ON users.user_id = addresses_6ca7.user_id - WHERE addresses.street = :addresses_street ORDER BY users.oid, addresses_6ca7.oid - {'addresses_street': '123 Green Street'} - -The join resulting from `join('addresses')` is separate from the join produced by the eager load, which is "aliasized" to prevent conflicts. - -#### Using Options to Change the Loading Strategy {@name=options} - -The `options()` method on the `Query` object allows modifications to the underlying querying methodology. The most common use of this feature is to change the "eager/lazy" loading behavior of a particular mapper, via the functions `eagerload()`, `lazyload()` and `noload()`: - - {python} - # user mapper with lazy addresses - mapper(User, users_table, properties = { - 'addresses' : relation(mapper(Address, addresses_table)) - } - ) - - # query object - query = session.query(User) - - # make an eager loading query - eagerquery = query.options(eagerload('addresses')) - u = eagerquery.all() - - # make another query that won't load the addresses at all - plainquery = query.options(noload('addresses')) - - # multiple options can be specified - myquery = oldquery.options(lazyload('tracker'), noload('streets'), eagerload('members')) - - # to specify a relation on a relation, separate the property names by a "." - myquery = oldquery.options(eagerload('orders.items')) - -### More Relationships {@name=morerelations} - -Previously, we've discussed how to set up a one-to-many relationship. This section will go over the remaining major types of relationships that can be configured. More detail on relationships as well as more advanced patterns can be found in [advdatamapping](rel:advdatamapping). - -#### One to One/Many to One {@name=manytoone} - -The above examples focused on the "one-to-many" relationship.
To do other forms of relationship is easy, as the `relation()` function can usually figure out what you want: - - {python} - metadata = MetaData() - - # a table to store a user's preferences for a site - prefs_table = Table('user_prefs', metadata, - Column('pref_id', Integer, primary_key = True), - Column('stylename', String(20)), - Column('save_password', Boolean, nullable = False), - Column('timezone', CHAR(3), nullable = False) - ) - - # user table with a 'preference_id' column - users_table = Table('users', metadata, - Column('user_id', Integer, primary_key = True), - Column('user_name', String(16), nullable = False), - Column('password', String(20), nullable = False), - Column('preference_id', Integer, ForeignKey("user_prefs.pref_id")) - ) - - # engine and some test data - engine = create_engine('sqlite:///', echo=True) - metadata.create_all(engine) - engine.execute(prefs_table.insert(), dict(pref_id=1, stylename='green', save_password=1, timezone='EST')) - engine.execute(users_table.insert(), dict(user_name = 'fred', password='45nfss', preference_id=1)) - - # classes - class User(object): - def __init__(self, user_name, password): - self.user_name = user_name - self.password = password - - class UserPrefs(object): - pass - - mapper(UserPrefs, prefs_table) - - mapper(User, users_table, properties = { - 'preferences':relation(UserPrefs, lazy=False, cascade="all, delete-orphan"), - }) - - # select - session = create_session(bind=engine) - {sql}user = session.query(User).filter_by(user_name='fred').one() - SELECT users.preference_id AS users_preference_id, users.user_name AS users_user_name, - users.password AS users_password, users.user_id AS users_user_id, - user_prefs_4eb2.timezone AS user_prefs_4eb2_timezone, user_prefs_4eb2.stylename AS user_prefs_4eb2_stylename, - user_prefs_4eb2.save_password AS user_prefs_4eb2_save_password, user_prefs_4eb2.pref_id AS user_prefs_4eb2_pref_id - FROM (SELECT users.user_id AS users_user_id FROM users WHERE users.user_name = :users_user_name ORDER BY users.oid - LIMIT 1 OFFSET 0) AS rowcount, - users LEFT OUTER JOIN user_prefs AS user_prefs_4eb2 ON user_prefs_4eb2.pref_id = users.preference_id - WHERE rowcount.users_user_id = users.user_id ORDER BY users.oid, user_prefs_4eb2.oid - {'users_user_name': 'fred'} - - save_password = user.preferences.save_password - - # modify - user.preferences.stylename = 'bluesteel' - - # flush - {sql}session.flush() - UPDATE user_prefs SET stylename=:stylename - WHERE user_prefs.pref_id = :pref_id - [{'stylename': 'bluesteel', 'pref_id': 1}] - -#### Many to Many {@name=manytomany} - -The `relation()` function handles a basic many-to-many relationship when you specify an association table using the `secondary` argument: - - {python} - metadata = MetaData() - - articles_table = Table('articles', metadata, - Column('article_id', Integer, primary_key = True), - Column('headline', String(150), key='headline'), - Column('body', TEXT, key='body'), - ) - - keywords_table = Table('keywords', metadata, - Column('keyword_id', Integer, primary_key = True), - Column('keyword_name', String(50)) - ) - - itemkeywords_table = Table('article_keywords', metadata, - Column('article_id', Integer, ForeignKey("articles.article_id")), - Column('keyword_id', Integer, ForeignKey("keywords.keyword_id")) - ) - - engine = create_engine('sqlite:///') - metadata.create_all(engine) - - # class definitions - class Keyword(object): - def __init__(self, name): - self.keyword_name = name - - class Article(object): - pass - - mapper(Keyword, 
keywords_table) - - # define a mapper that does many-to-many on the 'itemkeywords' association - # table - mapper(Article, articles_table, properties = { - 'keywords':relation(Keyword, secondary=itemkeywords_table, lazy=False) - } - ) - - session = create_session(bind=engine) - - article = Article() - article.headline = 'a headline' - article.body = 'this is the body' - article.keywords.append(Keyword('politics')) - article.keywords.append(Keyword('entertainment')) - session.save(article) - - {sql}session.flush() - INSERT INTO keywords (name) VALUES (:name) - {'name': 'politics'} - INSERT INTO keywords (name) VALUES (:name) - {'name': 'entertainment'} - INSERT INTO articles (article_headline, article_body) VALUES (:article_headline, :article_body) - {'article_body': 'this is the body', 'article_headline': 'a headline'} - INSERT INTO article_keywords (article_id, keyword_id) VALUES (:article_id, :keyword_id) - [{'keyword_id': 1, 'article_id': 1}, {'keyword_id': 2, 'article_id': 1}] - - # select articles based on a keyword. - {sql}articles = session.query(Article).join('keywords').filter_by(keyword_name='politics').all() - SELECT keywords_e2f2.keyword_id AS keywords_e2f2_keyword_id, keywords_e2f2.keyword_name AS keywords_e2f2_keyword_name, - articles.headline AS articles_headline, articles.body AS articles_body, articles.article_id AS articles_article_id - FROM keywords, article_keywords, articles - LEFT OUTER JOIN article_keywords AS article_keyword_3da2 ON articles.article_id = article_keyword_3da2.article_id - LEFT OUTER JOIN keywords AS keywords_e2f2 ON keywords_e2f2.keyword_id = article_keyword_3da2.keyword_id - WHERE (keywords.keyword_name = :keywords_keywords_name AND articles.article_id = article_keywords.article_id) - AND keywords.keyword_id = article_keywords.keyword_id ORDER BY articles.oid, article_keyword_3da2.oid - {'keywords_keyword_name': 'politics'} - - a = articles[0] - - # clear out keywords with a new list - a.keywords = [] - a.keywords.append(Keyword('topstories')) - a.keywords.append(Keyword('government')) - - # flush - {sql}session.flush() - INSERT INTO keywords (name) VALUES (:name) - {'name': 'topstories'} - INSERT INTO keywords (name) VALUES (:name) - {'name': 'government'} - DELETE FROM article_keywords - WHERE article_keywords.article_id = :article_id - AND article_keywords.keyword_id = :keyword_id - [{'keyword_id': 1, 'article_id': 1}, {'keyword_id': 2, 'article_id': 1}] - INSERT INTO article_keywords (article_id, keyword_id) VALUES (:article_id, :keyword_id) - [{'keyword_id': 3, 'article_id': 1}, {'keyword_id': 4, 'article_id': 1}] - -#### Association Object {@name=association} - -Many to Many can also be done with an association object, that adds additional information about how two items are related. In this pattern, the "secondary" option to `relation()` is no longer used; instead, the association object becomes a mapped entity itself, mapped to the association table. If the association table has no explicit primary key columns defined, you also have to tell the mapper what columns will compose its "primary key", which are typically the two (or more) columns involved in the association. Also, the relation between the parent and association mapping is typically set up with a cascade of `all, delete-orphan`. 
This is to ensure that when an association object is removed from its parent collection, it is deleted (otherwise, the unit of work tries to null out one of the foreign key columns, which raises an error condition since that column is also part of its "primary key"). - - {python} - from sqlalchemy import * - metadata = MetaData() - - users_table = Table('users', metadata, - Column('user_id', Integer, primary_key = True), - Column('user_name', String(16), nullable = False), - ) - - articles_table = Table('articles', metadata, - Column('article_id', Integer, primary_key = True), - Column('headline', String(150), key='headline'), - Column('body', TEXT, key='body'), - ) - - keywords_table = Table('keywords', metadata, - Column('keyword_id', Integer, primary_key = True), - Column('keyword_name', String(50)) - ) - - # add "attached_by" column which will reference the user who attached this keyword - itemkeywords_table = Table('article_keywords', metadata, - Column('article_id', Integer, ForeignKey("articles.article_id")), - Column('keyword_id', Integer, ForeignKey("keywords.keyword_id")), - Column('attached_by', Integer, ForeignKey("users.user_id")) - ) - - engine = create_engine('sqlite:///', echo=True) - metadata.create_all(engine) - - # class definitions - class User(object): - pass - class Keyword(object): - def __init__(self, name): - self.keyword_name = name - class Article(object): - pass - class KeywordAssociation(object): - pass - - # Article mapper, relates to Keyword via KeywordAssociation - mapper(Article, articles_table, properties={ - 'keywords':relation(KeywordAssociation, lazy=False, cascade="all, delete-orphan") - } - ) - - # mapper for KeywordAssociation - # specify "primary key" columns manually - mapper(KeywordAssociation, itemkeywords_table, - primary_key=[itemkeywords_table.c.article_id, itemkeywords_table.c.keyword_id], - properties={ - 'keyword' : relation(Keyword, lazy=False), - 'user' : relation(User, lazy=False) - } - ) - - # user mapper - mapper(User, users_table) - - # keyword mapper - mapper(Keyword, keywords_table) - - session = create_session(bind=engine) - # select by keyword - {sql}alist = session.query(Article).join(['keywords', 'keyword']).filter_by(keyword_name='jacks_stories').all() - SELECT article_keyword_f9af.keyword_id AS article_keyword_f9af_key_b3e1, - article_keyword_f9af.attached_by AS article_keyword_f9af_att_95d4, - article_keyword_f9af.article_id AS article_keyword_f9af_art_fd49, - users_9c30.user_name AS users_9c30_user_name, users_9c30.user_id AS users_9c30_user_id, - keywords_dc54.keyword_id AS keywords_dc54_keyword_id, keywords_dc54.keyword_name AS keywords_dc54_keyword_name, - articles.headline AS articles_headline, articles.body AS articles_body, articles.article_id AS articles_article_id - FROM keywords, article_keywords, articles - LEFT OUTER JOIN article_keywords AS article_keyword_f9af ON articles.article_id = article_keyword_f9af.article_id - LEFT OUTER JOIN users AS users_9c30 ON users_9c30.user_id = article_keyword_f9af.attached_by - LEFT OUTER JOIN keywords AS keywords_dc54 ON keywords_dc54.keyword_id = article_keyword_f9af.keyword_id - WHERE (keywords.keyword_name = :keywords_keywords_name AND keywords.keyword_id = article_keywords.keyword_id) - AND articles.article_id = article_keywords.article_id - ORDER BY articles.oid, article_keyword_f9af.oid, users_9c30.oid, keywords_dc54.oid - {'keywords_keywords_name': 'jacks_stories'} - - # user is available - for a in alist: - for k in a.keywords: - if k.keyword.name == 'jacks_stories': - 
print k.user.user_name - -Keep in mind that the association object works a little differently from a plain many-to-many relationship. Members have to be added to the list via instances of the association object, which in turn point to the associated object: - - {python} - user = User() - user.user_name = 'some user' - - article = Article() - - assoc = KeywordAssociation() - assoc.keyword = Keyword('blue') - assoc.user = user - - assoc2 = KeywordAssociation() - assoc2.keyword = Keyword('green') - assoc2.user = user - - article.keywords.append(assoc) - article.keywords.append(assoc2) - - session.save(article) - - session.flush() - -SQLAlchemy includes an extension module which can be used in some cases to decrease the explicitness of the association object pattern; this extension is described in [plugins_associationproxy](rel:plugins_associationproxy). - -Note that you should **not** combine the usage of a `secondary` relationship with an association object pattern against the same association table. This is because SQLAlchemy's unit of work will regard rows in the table tracked by the `secondary` argument as distinct from entities mapped into the table by the association mapper, causing unexpected behaviors when rows are changed by one mapping and not the other. diff --git a/doc/build/content/dbengine.txt b/doc/build/content/dbengine.txt index c03d569374..0236a6f70f 100644 --- a/doc/build/content/dbengine.txt +++ b/doc/build/content/dbengine.txt @@ -5,6 +5,7 @@ The **Engine** is the starting point for any SQLAlchemy application. It's "home The general structure is this: + {diagram} +-----------+ __________ /---| Pool |---\ (__________) +-------------+ / +-----------+ \ +--------+ | | @@ -45,9 +46,9 @@ To execute some SQL more quickly, you can skip the `Connection` part and just sa Where above, the `execute()` method on the `Engine` does the `connect()` part for you, and returns the `ResultProxy` directly. The actual `Connection` is *inside* the `ResultProxy`, waiting for you to finish reading the result. In this case, when you `close()` the `ResultProxy`, the underlying `Connection` is closed, which returns the DBAPI connection to the pool. -To summarize the above two examples, when you use a `Connection` object, its known as **explicit execution**. When you don't see the `Connection` object, but you still use the `execute()` method on the `Engine`, its called **explicit, connectionless execution**. A third variant of execution also exists called **implicit execution**; this will be described later. +To summarize the above two examples, when you use a `Connection` object, it's known as **explicit execution**. When you don't see the `Connection` object, but you still use the `execute()` method on the `Engine`, it's called **explicit, connectionless execution**. A third variant of execution also exists called **implicit execution**; this will be described later. -The `Engine` and `Connection` can do a lot more than what we illustrated above; SQL strings are only its most rudimental function. Later chapters will describe how "constructed SQL" expressions can be used with engines; in many cases, you don't have to deal with the `Engine` at all after it's created. The Object Relational Mapper (ORM), an optional feature of SQLAlchemy, also uses the `Engine` in order to get at connections; that's also a case where you can often create the engine once, and then forget about it. 
+The `Engine` and `Connection` can do a lot more than what we illustrated above; SQL strings are only its most rudimentary function. Later chapters will describe how "constructed SQL" expressions can be used with engines; in many cases, you don't have to deal with the `Engine` at all after it's created. The Object Relational Mapper (ORM), an optional feature of SQLAlchemy, also uses the `Engine` in order to get at connections; that's also a case where you can often create the engine once, and then forget about it. ### Supported Databases {@name=supported} @@ -125,13 +126,14 @@ Keyword options can also be specified to `create_engine()`, following the string A list of all standard options, as well as several that are used by particular database dialects, is as follows: +* **assert_unicode=False** - When set to `True` alongside convert_unicode=`True`, asserts that incoming string bind parameters are instances of `unicode`, otherwise raises an error. Only takes effect when `convert_unicode==True`. This flag is also available on the `String` type and its descendants. New in 0.4.2. * **connect_args** - a dictionary of options which will be passed directly to the DBAPI's `connect()` method as additional keyword arguments. * **convert_unicode=False** - if set to True, all String/character based types will convert Unicode values to raw byte values going into the database, and all raw byte values to Python Unicode coming out in result sets. This is an engine-wide method to provide unicode conversion across the board. For unicode conversion on a column-by-column level, use the `Unicode` column type instead, described in [types](rel:types). * **creator** - a callable which returns a DBAPI connection. This creation function will be passed to the underlying connection pool and will be used to create all new database connections. Usage of this function causes connection parameters specified in the URL argument to be bypassed. * **echo=False** - if True, the Engine will log all statements as well as a repr() of their parameter lists to the engines logger, which defaults to sys.stdout. The `echo` attribute of `Engine` can be modified at any time to turn logging on and off. If set to the string `"debug"`, result rows will be printed to the standard output as well. This flag ultimately controls a Python logger; see [dbengine_logging](rel:dbengine_logging) at the end of this chapter for information on how to configure logging directly. * **echo_pool=False** - if True, the connection pool will log all checkouts/checkins to the logging stream, which defaults to sys.stdout. This flag ultimately controls a Python logger; see [dbengine_logging](rel:dbengine_logging) for information on how to configure logging directly. * **encoding='utf-8'** - the encoding to use for all Unicode translations, both by engine-wide unicode conversion as well as the `Unicode` type object. -* **module=None** - used by database implementations which support multiple DBAPI modules, this is a reference to a DBAPI2 module to be used instead of the engine's default module. For Postgres, the default is psycopg2. For Oracle, its cx_Oracle. +* **module=None** - used by database implementations which support multiple DBAPI modules, this is a reference to a DBAPI2 module to be used instead of the engine's default module. For Postgres, the default is psycopg2. For Oracle, it's cx_Oracle. * **pool=None** - an already-constructed instance of `sqlalchemy.pool.Pool`, such as a `QueuePool` instance. 
If non-None, this pool will be used directly as the underlying connection pool for the engine, bypassing whatever connection parameters are present in the URL argument. For information on constructing connection pools manually, see [pooling](rel:pooling). * **poolclass=None** - a `sqlalchemy.pool.Pool` subclass, which will be used to create a connection pool instance using the connection parameters given in the URL. Note this differs from `pool` in that you don't actually instantiate the pool in this case, you just indicate what type of pool to be used. * **max_overflow=10** - the number of connections to allow in connection pool "overflow", that is connections that can be opened above and beyond the pool_size setting, which defaults to five. this is only used with `QueuePool`. @@ -140,7 +142,7 @@ A list of all standard options, as well as several that are used by particular d * **pool_timeout=30** - number of seconds to wait before giving up on getting a connection from the pool. This is only used with `QueuePool`. * **strategy='plain'** - the Strategy argument is used to select alternate implementations of the underlying Engine object, which coordinates operations between dialects, compilers, connections, and so on. Currently, the only alternate strategy besides the default value of "plain" is the "threadlocal" strategy, which selects the usage of the `TLEngine` class that provides a modified connection scope for connectionless executions. Connectionless execution as well as further detail on this setting are described in [dbengine_implicit](rel:dbengine_implicit). * **threaded=True** - used by cx_Oracle; sets the `threaded` parameter of the connection indicating thread-safe usage. cx_Oracle docs indicate setting this flag to `False` will speed performance by 10-15%. While this defaults to `False` in cx_Oracle, SQLAlchemy defaults it to `True`, preferring stability over early optimization. -* **use_ansi=True** - used only by Oracle; when False, the Oracle driver attempts to support a particular "quirk" of Oracle versions 8 and previous, that the LEFT OUTER JOIN SQL syntax is not supported, and the "Oracle join" syntax of using `<column1>(+)=<column2>` must be used in order to achieve a LEFT OUTER JOIN. +* **use_ansi=True** - used only by Oracle; when False, the Oracle driver attempts to support a particular "quirk" of Oracle versions 8 and previous, that the LEFT OUTER JOIN SQL syntax is not supported, and the "Oracle join" syntax of using `column1(+)=column2` must be used in order to achieve a LEFT OUTER JOIN. * **use_oids=False** - used only by Postgres, will enable the column name "oid" as the object ID column, which is also used for the default sort order of tables. Postgres as of 8.1 has object IDs disabled by default. ### More On Connections {@name=connections} @@ -241,9 +243,9 @@ The above transaction example illustrates how to use `Transaction` so that sever ### Connectionless Execution, Implicit Execution {@name=implicit} -Recall from the first section we mentioned executing with and without a `Connection`. `Connectionless` execution refers to calling the `execute()` method on an object which is not a `Connection`, which could be on the the `Engine` itself, or could be a constructed SQL object. When we say "implicit", we mean that we are calling the `execute()` method on an object which is neither a `Connection` nor an `Engine` object; this can only be used with constructed SQL objects which have their own `execute()` method, and can be "bound" to an `Engine`. 
A description of "constructed SQL objects" may be found in [sql](rel:sql). +Recall from the first section we mentioned executing with and without a `Connection`. `Connectionless` execution refers to calling the `execute()` method on an object which is not a `Connection`, which could be on the `Engine` itself, or could be a constructed SQL object. When we say "implicit", we mean that we are calling the `execute()` method on an object which is neither a `Connection` nor an `Engine` object; this can only be used with constructed SQL objects which have their own `execute()` method, and can be "bound" to an `Engine`. A description of "constructed SQL objects" may be found in [sql](rel:sql). -A summary of all three methods follows below. First, assume the usage of the following `MetaData` and `Table` objects; while we haven't yet introduced these concepts, for now you only need to know that we are representing a database table, and are creating an "executeable" SQL construct which issues a statement to the database. These objects are described in [metadata](rel:metadata). +A summary of all three methods follows below. First, assume the usage of the following `MetaData` and `Table` objects; while we haven't yet introduced these concepts, for now you only need to know that we are representing a database table, and are creating an "executable" SQL construct which issues a statement to the database. These objects are described in [metadata](rel:metadata). {python} meta = MetaData() @@ -275,7 +277,7 @@ Implicit execution is also connectionless, and calls the `execute()` method on t {python} engine = create_engine('sqlite:///file.db') - meta.connect(engine) + meta.bind = engine result = users_table.select().execute() for row in result: # .... @@ -333,7 +335,7 @@ The usage of "threadlocal" modifies the underlying behavior of our example above Where above, we again have two result sets in scope at the same time, but because they are present in the same thread, there is only **one DBAPI connection in use**. -While the above distinction may not seem like much, it has several potentially desireable effects. One is that you can in some cases reduce the number of concurrent connections checked out from the connection pool, in the case that a `ResultProxy` is still opened and a second statement is issued. A second advantage is that by limiting the number of checked out connections in a thread to just one, you eliminate the issue of deadlocks within a single thread, such as when connection A locks a table, and connection B attempts to read from the same table in the same thread, it will "deadlock" on waiting for connection A to release its lock; the `threadlocal` strategy eliminates this possibility. +While the above distinction may not seem like much, it has several potentially desirable effects. One is that you can in some cases reduce the number of concurrent connections checked out from the connection pool, in the case that a `ResultProxy` is still opened and a second statement is issued. A second advantage is that by limiting the number of checked out connections in a thread to just one, you eliminate the issue of deadlocks within a single thread, such as when connection A locks a table, and connection B attempts to read from the same table in the same thread, it will "deadlock" on waiting for connection A to release its lock; the `threadlocal` strategy eliminates this possibility. 
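As a minimal sketch of how this strategy is selected (the database URL, table name, and raw SQL strings here are placeholders for illustration only), `strategy='threadlocal'` is passed to `create_engine()`; connectionless executions issued in the same thread will then draw from a single, shared DBAPI connection:

    {python}
    from sqlalchemy import create_engine

    # placeholder URL; any supported database behaves the same way
    tl_engine = create_engine('sqlite:///somefile.db', strategy='threadlocal')

    # two connectionless executions in the same thread; both result sets
    # are read from the same underlying DBAPI connection
    result1 = tl_engine.execute("select * from users")
    result2 = tl_engine.execute("select * from users where user_id=1")

    result1.close()
    result2.close()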
A third advantage to the `threadlocal` strategy is that it allows the `Transaction` object to be used in combination with connectionless execution. Recall from the section on transactions, that the `Transaction` is returned by the `begin()` method on a `Connection`; all statements which wish to participate in this transaction must be executed by the same `Connection`, thereby forcing the usage of an explicit connection. However, the `TLEngine` provides a `Transaction` that is local to the current thread; using it, one can issue many "connectionless" statements within a thread and they will all automatically partake in the current transaction, as in the example below: @@ -382,7 +384,7 @@ Complex application flows can take advantage of the "threadlocal" strategy in or except: engine.rollback() -In the above example, the program calls three functions `dosomethingimplicit()`, `dosomethingelse()` and `dosomethingtransactional()`. All three functions use either connectionless execution, or a special function `contextual_connect()` which we will describe in a moment. These two styles of execution both indicate that all executions will use the same connection object. Additionally, the method `dosomethingtransactional()` begins and commits its own `Transaction`. But only one transaction is used, too; it's controlled completely by the `engine.begin()`/`engine.commit()` calls at the bottom. Recall that `Transaction` supports "nesting" behavior, whereby transactions begun on a `Connection` which already has a tranasaction open, will "nest" into the enclosing transaction. Since the transaction opened in `dosomethingtransactional()` occurs using the same connection which already has a transaction begun, it "nests" into that transaction and therefore has no effect on the actual transaction scope (unless it calls `rollback()`). +In the above example, the program calls three functions `dosomethingimplicit()`, `dosomethingelse()` and `dosomethingtransactional()`. All three functions use either connectionless execution, or a special function `contextual_connect()` which we will describe in a moment. These two styles of execution both indicate that all executions will use the same connection object. Additionally, the method `dosomethingtransactional()` begins and commits its own `Transaction`. But only one transaction is used, too; it's controlled completely by the `engine.begin()`/`engine.commit()` calls at the bottom. Recall that `Transaction` supports "nesting" behavior, whereby transactions begun on a `Connection` which already has a transaction open, will "nest" into the enclosing transaction. Since the transaction opened in `dosomethingtransactional()` occurs using the same connection which already has a transaction begun, it "nests" into that transaction and therefore has no effect on the actual transaction scope (unless it calls `rollback()`). Some of the functions in the above example make use of a method called `engine.contextual_connect()`. This method is available on both `Engine` as well as `TLEngine`, and returns the `Connection` that applies to the current **connection context**. When using the `TLEngine`, this is just another term for the "thread local connection" that is being used for all connectionless executions. When using just the regular `Engine` (i.e. the "default" strategy), `contextual_connect()` is synonymous with `connect()`. 
Below we illustrate that two connections opened via `contextual_connect()` at the same time, both reference the same underlying DBAPI connection: @@ -396,7 +398,7 @@ Some of the functions in the above example make use of a method called `engine.c >>> conn1.connection is conn2.connection True -The basic idea of `contextual_connect()` is that its the "connection used by connectionless execution". It's different from the `connect()` method in that `connect()` is always used when handling an explicit `Connection`, which will always reference distinct DBAPI connection. Using `connect()` in combination with `TLEngine` allows one to "circumvent" the current thread local context, as in this example where a single statement issues data to the database externally to the current transaction: +The basic idea of `contextual_connect()` is that it's the "connection used by connectionless execution". It's different from the `connect()` method in that `connect()` is always used when handling an explicit `Connection`, which will always reference distinct DBAPI connection. Using `connect()` in combination with `TLEngine` allows one to "circumvent" the current thread local context, as in this example where a single statement issues data to the database externally to the current transaction: {python} engine.begin() diff --git a/doc/build/content/intro.txt b/doc/build/content/intro.txt new file mode 100644 index 0000000000..d6ded5bdf8 --- /dev/null +++ b/doc/build/content/intro.txt @@ -0,0 +1,166 @@ +Overview / Installation +============ + +## Overview + +The SQLAlchemy SQL Toolkit and Object Relational Mapper is a comprehensive set of tools for working with databases and Python. It has several distinct areas of functionality which can be used individually or combined together. Its major API components, all public-facing, are illustrated below: + + {diagram} + +-----------------------------------------------------------+ + | Object Relational Mapper (ORM) | + | [[tutorial]](rel:datamapping) [[docs]](rel:advdatamapping) | + +-----------------------------------------------------------+ + +---------+ +------------------------------------+ +--------+ + | | | SQL Expression Language | | | + | | | [[tutorial]](rel:sql) [[docs]](rel:docstrings_sqlalchemy.sql.expression) | | | + | | +------------------------------------+ | | + | +-----------------------+ +--------------+ | + | Dialect/Execution | | Schema Management | + | [[docs]](rel:dbengine) | | [[docs]](rel:metadata) | + +---------------------------------+ +-----------------------+ + +----------------------+ +----------------------------------+ + | Connection Pooling | | Types | + | [[docs]](rel:pooling) | | [[docs]](rel:types) | + +----------------------+ +----------------------------------+ + +Above, the two most significant front-facing portions of SQLAlchemy are the **Object Relational Mapper** and the **SQL Expression Language**. These are two separate toolkits, one building off the other. SQL Expressions can be used independently of the ORM. When using the ORM, the SQL Expression language is used to establish object-relational configurations as well as in querying. + +## Tutorials + + * [Object Relational Tutorial](rel:datamapping) - This describes the richest feature of SQLAlchemy, its object relational mapper. If you want to work with higher-level SQL which is constructed automatically for you, as well as management of Python objects, proceed to this tutorial. + * [SQL Expression Tutorial](rel:sql) - The core of SQLAlchemy is its SQL expression language. 
The SQL Expression Language is a toolkit all its own, independent of the ORM package, which can be used to construct manipulable SQL expressions which can be programmatically constructed, modified, and executed, returning cursor-like result sets. It's a lot more lightweight than the ORM and is appropriate for higher scaling SQL operations. It's also heavily present within the ORM's public facing API, so advanced ORM users will want to master this language as well. + +## Reference Documentation + + * [Datamapping](rel:advdatamapping) - A comprehensive walkthrough of major ORM patterns and techniques. + * [Session](rel:unitofwork) - A detailed description of SQLAlchemy's Session object + * [Engines](rel:dbengine) - Describes SQLAlchemy's database-connection facilities, including connection documentation and working with connections and transactions. + * [Connection Pools](rel:pooling) - Further detail about SQLAlchemy's connection pool library. + * [Metadata](rel:metadata) - All about schema management using `MetaData` and `Table` objects; reading database schemas into your application, creating and dropping tables, constraints, defaults, sequences, indexes. + * [Types](rel:types) - Datatypes included with SQLAlchemy, their functions, as well as how to create your own types. + * [Plugins](rel:plugins) - Included addons for SQLAlchemy + +## Installing SQLAlchemy {@name=sqlalchemy} + +Installing SQLAlchemy from scratch is most easily achieved with [setuptools][]. ([setuptools installation][install setuptools]). Just run this from the command-line: + + # easy_install SQLAlchemy + +This command will download the latest version of SQLAlchemy from the [Python Cheese Shop][pypi] and install it to your system. + +[setuptools]: http://peak.telecommunity.com/DevCenter/setuptools +[install setuptools]: http://peak.telecommunity.com/DevCenter/EasyInstall#installation-instructions +[pypi]: http://pypi.python.org/pypi/SQLAlchemy + +Otherwise, you can install from the distribution using the `setup.py` script: + + # python setup.py install + +### Installing a Database API {@name=dbms} + +SQLAlchemy is designed to operate with a [DB-API](http://www.python.org/doc/peps/pep-0249/) implementation built for a particular database, and includes support for the most popular databases: + +* Postgres: [psycopg2](http://www.initd.org/tracker/psycopg/wiki/PsycopgTwo) +* SQLite: [pysqlite](http://initd.org/tracker/pysqlite), [sqlite3](http://docs.python.org/lib/module-sqlite3.html) (included with Python 2.5 or greater) +* MySQL: [MySQLdb](http://sourceforge.net/projects/mysql-python) +* Oracle: [cx_Oracle](http://www.cxtools.net/default.aspx?nav=home) +* MS-SQL: [pyodbc](http://pyodbc.sourceforge.net/) (recommended), [adodbapi](http://adodbapi.sourceforge.net/) or [pymssql](http://pymssql.sourceforge.net/) +* Firebird: [kinterbasdb](http://kinterbasdb.sourceforge.net/) +* Informix: [informixdb](http://informixdb.sourceforge.net/) + +### Checking the Installed SQLAlchemy Version + +This documentation covers SQLAlchemy version 0.4. If you're working on a system that already has SQLAlchemy installed, check the version from your Python prompt like this: + + {python} + >>> import sqlalchemy + >>> sqlalchemy.__version__ # doctest: +SKIP + 0.4.0 + +## 0.3 to 0.4 Migration {@name=migration} + +From version 0.3 to version 0.4 of SQLAlchemy, some conventions have changed. 
Most of these conventions are available in the most recent releases of the 0.3 series starting with version 0.3.9, so that you can make a 0.3 application compatible with 0.4 in most cases. + +This section will detail only those things that have changed in a backwards-incompatible manner. For a full overview of everything that's new and changed, see [WhatsNewIn04](http://www.sqlalchemy.org/trac/wiki/WhatsNewIn04). + +### ORM Package is now sqlalchemy.orm {@name=imports} + +All symbols related to the SQLAlchemy Object Relational Mapper, i.e. names like `mapper()`, `relation()`, `backref()`, `create_session()` `synonym()`, `eagerload()`, etc. are now only in the `sqlalchemy.orm` package, and **not** in `sqlalchemy`. So if you were previously importing everything on an asterisk: + + {python} + from sqlalchemy import * + +You should now import separately from orm: + + {python} + from sqlalchemy import * + from sqlalchemy.orm import * + +Or more commonly, just pull in the names you'll need: + + {python} + from sqlalchemy import create_engine, MetaData, Table, Column, types + from sqlalchemy.orm import mapper, relation, backref, create_session + +### BoundMetaData is now MetaData {@name=metadata} + +The `BoundMetaData` name is removed. Now, you just use `MetaData`. Additionally, the `engine` parameter/attribute is now called `bind`, and `connect()` is deprecated: + + {python} + # plain metadata + meta = MetaData() + + # metadata bound to an engine + meta = MetaData(engine) + + # bind metadata to an engine later + meta.bind = engine + +Additionally, `DynamicMetaData` is now known as `ThreadLocalMetaData`. + +### "Magic" Global MetaData removed {@name=global} + +There was an old way to specify `Table` objects using an implicit, global `MetaData` object. To do this you'd omit the second positional argument, and specify `Table('tablename', Column(...))`. This no longer exists in 0.4 and the second `MetaData` positional argument is required, i.e. `Table('tablename', meta, Column(...))`. + +### Some existing select() methods become generative {@name=generative} + +The methods `correlate()`, `order_by()`, and `group_by()` on the `select()` construct now return a **new** select object, and do not change the original one. Additionally, the generative methods `where()`, `column()`, `distinct()`, and several others have been added: + + {python} + s = table.select().order_by(table.c.id).where(table.c.x==7) + result = engine.execute(s) + +### collection_class behavior is changed {@name=collection} + +If you've been using the `collection_class` option on `mapper()`, the requirements for instrumented collections have changed. For an overview, see [advdatamapping_relation_collections](rel:advdatamapping_relation_collections). 
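As a quick sketch of the simplest case (the `Parent` and `Child` classes and their tables below are hypothetical), handing a built-in type such as `set` to `collection_class` keeps working unchanged in 0.4; it is custom collection classes that need to meet the new instrumentation requirements described in the linked chapter:

    {python}
    from sqlalchemy.orm import mapper, relation

    class Parent(object):
        pass
    class Child(object):
        pass

    # hypothetical parent_table / child_table are assumed to be defined elsewhere;
    # a plain built-in "set" remains a valid collection_class in 0.4
    mapper(Parent, parent_table, properties={
        'children': relation(Child, collection_class=set)
    })
    mapper(Child, child_table)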
+ +### All "engine", "bind_to", "connectable" Keyword Arguments Changed to "bind" {@name=bind} + +This is for create/drop statements, sessions, SQL constructs, metadatas: + + {python} + myengine = create_engine('sqlite://') + + meta = MetaData(myengine) + + meta2 = MetaData() + meta2.bind = myengine + + session = create_session(bind=myengine) + + statement = select([table], bind=myengine) + + meta.create_all(bind=myengine) + +### All "type" Keyword Arguments Changed to "type_" {@name=type} + +This mostly applies to SQL constructs where you pass a type in: + + {python} + s = select([mytable], mytable.c.x==bindparam('y', type_=DateTime)) + + func.now(type_=DateTime) + +### Mapper Extensions must return EXT_CONTINUE to continue execution to the next mapper + +If you extend the mapper, the methods in your mapper extension must return EXT_CONTINUE in order for processing to continue on to any additional mapper extensions. diff --git a/doc/build/content/mappers.txt b/doc/build/content/mappers.txt new file mode 100644 index 0000000000..fca2076bc7 --- /dev/null +++ b/doc/build/content/mappers.txt @@ -0,0 +1,1515 @@ +[alpha_api]: javascript:alphaApi() +[alpha_implementation]: javascript:alphaImplementation() + +Mapper Configuration {@name=advdatamapping} +====================== + +This section references most major configuration patterns involving the [mapper()](rel:docstrings_sqlalchemy.orm_modfunc_mapper) and [relation()](rel:docstrings_sqlalchemy.orm_modfunc_relation) functions. It assumes you've worked through the [datamapping](rel:datamapping) tutorial and know how to construct and use rudimentary mappers and relations. + +### Mapper Configuration + +Full API documentation for the ORM: + +[docstrings_sqlalchemy.orm](rel:docstrings_sqlalchemy.orm). + +Options for the `mapper()` function: + +[docstrings_sqlalchemy.orm_modfunc_mapper](rel:docstrings_sqlalchemy.orm_modfunc_mapper). + +#### Customizing Column Properties {@name=columns} + +The default behavior of a `mapper` is to assemble all the columns in the mapped `Table` into mapped object attributes. This behavior can be modified in several ways, as well as enhanced by SQL expressions. + +To load only a part of the columns referenced by a table as attributes, use the `include_properties` and `exclude_properties` arguments: + + {python} + mapper(User, users_table, include_properties=['user_id', 'user_name']) + + mapper(Address, addresses_table, exclude_properties=['street', 'city', 'state', 'zip']) + +To change the name of the attribute mapped to a particular column, place the `Column` object in the `properties` dictionary with the desired key: + + {python} + mapper(User, users_table, properties={ + 'id' : users_table.c.user_id, + 'name' : users_table.c.user_name, + }) + +To change the names of all attributes using a prefix, use the `column_prefix` option. This is useful for classes which wish to add their own `property` accessors: + + {python} + mapper(User, users_table, column_prefix='_') + +The above will place attribute names such as `_user_id`, `_user_name`, `_password` etc. on the mapped `User` class. 
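For example, here is a rough sketch of the pattern `column_prefix` is meant to support (the accessor names are arbitrary): the columns land on underscore-prefixed attributes, and the class layers an ordinary Python `property` on top of them. If the property also needs to participate in query criteria, the `synonym()` construct described later in this chapter is the more complete tool:

    {python}
    class User(object):
        def _get_name(self):
            return self._user_name
        def _set_name(self, name):
            self._user_name = name
        name = property(_get_name, _set_name)

    # column attributes become _user_id, _user_name, etc., leaving
    # the "name" property free for the class to define itself
    mapper(User, users_table, column_prefix='_')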
+ +To place multiple columns which are known to be "synonymous" based on foreign key relationship or join condition into the same mapped attribute, put them together using a list, as below where we map to a `Join`: + + {python} + # join users and addresses + usersaddresses = sql.join(users_table, addresses_table, \ + users_table.c.user_id == addresses_table.c.user_id) + + mapper(User, usersaddresses, + properties = { + 'id':[users_table.c.user_id, addresses_table.c.user_id], + }) + +#### Deferred Column Loading {@name=deferred} + +This feature allows particular columns of a table to not be loaded by default, instead being loaded later on when first referenced. It is essentially "column-level lazy loading". This feature is useful when one wants to avoid loading a large text or binary field into memory when it's not needed. Individual columns can be lazy loaded by themselves or placed into groups that lazy-load together. + + {python} + book_excerpts = Table('books', db, + Column('book_id', Integer, primary_key=True), + Column('title', String(200), nullable=False), + Column('summary', String(2000)), + Column('excerpt', String), + Column('photo', Binary) + ) + + class Book(object): + pass + + # define a mapper that will load each of 'excerpt' and 'photo' in + # separate, individual-row SELECT statements when each attribute + # is first referenced on the individual object instance + mapper(Book, book_excerpts, properties = { + 'excerpt' : deferred(book_excerpts.c.excerpt), + 'photo' : deferred(book_excerpts.c.photo) + }) + +Deferred columns can be placed into groups so that they load together: + + {python} + book_excerpts = Table('books', db, + Column('book_id', Integer, primary_key=True), + Column('title', String(200), nullable=False), + Column('summary', String(2000)), + Column('excerpt', String), + Column('photo1', Binary), + Column('photo2', Binary), + Column('photo3', Binary) + ) + + class Book(object): + pass + + # define a mapper with a 'photos' deferred group. when one photo is referenced, + # all three photos will be loaded in one SELECT statement. The 'excerpt' will + # be loaded separately when it is first referenced. + mapper(Book, book_excerpts, properties = { + 'excerpt' : deferred(book_excerpts.c.excerpt), + 'photo1' : deferred(book_excerpts.c.photo1, group='photos'), + 'photo2' : deferred(book_excerpts.c.photo2, group='photos'), + 'photo3' : deferred(book_excerpts.c.photo3, group='photos') + }) + +You can defer or undefer columns at the `Query` level using the `defer` and `undefer` options: + + {python} + query = session.query(Book) + query.options(defer('summary')).all() + query.options(undefer('excerpt')).all() + +And an entire "deferred group", i.e. which uses the `group` keyword argument to `deferred()`, can be undeferred using `undefer_group()`, sending in the group name: + + {python} + query = session.query(Book) + query.options(undefer_group('photos')).all() + +#### SQL Expressions as Mapped Attributes {@name=expressions} + +To add a SQL clause composed of local or external columns as a read-only, mapped column attribute, use the `column_property()` function. 
Any scalar-returning `ClauseElement` may be used, as long as it has a `name` attribute; usually, you'll want to call `label()` to give it a specific name: + + {python} + mapper(User, users_table, properties={ + 'fullname' : column_property( + (users_table.c.firstname + " " + users_table.c.lastname).label('fullname') + ) + }) + +Correlated subqueries may be used as well: + + {python} + mapper(User, users_table, properties={ + 'address_count' : column_property( + select( + [func.count(addresses_table.c.address_id)], + addresses_table.c.user_id==users_table.c.user_id + ).label('address_count') + ) + }) + +#### Overriding Attribute Behavior with Synonyms {@name=overriding} + +A common request is the ability to create custom class properties that override the behavior of setting/getting an attribute. As of 0.4.2, the `synonym()` construct provides an easy way to do this in conjunction with a normal Python `property` constructs. Below, we re-map the `email` column of our mapped table to a custom attribute setter/getter, mapping the actual column to the property named `_email`: + + {python} + class MyAddress(object): + def _set_email(self, email): + self._email = email + def _get_email(self): + return self._email + email = property(_get_email, _set_email) + + mapper(MyAddress, addresses_table, properties = { + 'email':synonym('_email', map_column=True) + }) + +The `email` attribute is now usable in the same way as any other mapped attribute, including filter expressions, get/set operations, etc.: + + {python} + address = sess.query(MyAddress).filter(MyAddress.email == 'some address').one() + + address.email = 'some other address' + sess.flush() + + q = sess.query(MyAddress).filter_by(email='some other address') + +If the mapped class does not provide a property, the `synonym()` construct will create a default getter/setter object automatically. + +#### Composite Column Types {@name=composite} + +Sets of columns can be associated with a single datatype. The ORM treats the group of columns like a single column which accepts and returns objects using the custom datatype you provide. In this example, we'll create a table `vertices` which stores a pair of x/y coordinates, and a custom datatype `Point` which is a composite type of an x and y column: + + {python} + vertices = Table('vertices', metadata, + Column('id', Integer, primary_key=True), + Column('x1', Integer), + Column('y1', Integer), + Column('x2', Integer), + Column('y2', Integer), + ) + +The requirements for the custom datatype class are that it have a constructor which accepts positional arguments corresponding to its column format, and also provides a method `__composite_values__()` which returns the state of the object as a list or tuple, in order of its column-based attributes. 
It should also supply adequate `__eq__()` and `__ne__()` methods which test the equality of two instances: + + {python} + class Point(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __composite_values__(self): + return [self.x, self.y] + def __eq__(self, other): + return other.x == self.x and other.y == self.y + def __ne__(self, other): + return not self.__eq__(other) + +Setting up the mapping uses the `composite()` function: + + {python} + class Vertex(object): + pass + + mapper(Vertex, vertices, properties={ + 'start':composite(Point, vertices.c.x1, vertices.c.y1), + 'end':composite(Point, vertices.c.x2, vertices.c.y2) + }) + +We can now create `Vertex` instances, as well as query them, as though the `start` and `end` attributes were regular scalar attributes: + + {python} + sess = Session() + v = Vertex(Point(3, 4), Point(5, 6)) + sess.save(v) + + v2 = sess.query(Vertex).filter(Vertex.start == Point(3, 4)) + +The "equals" comparison operation by default produces an AND of all corresponding columns equated to one another. If you'd like to override this, or define the behavior of other SQL operators for your new type, the `composite()` function accepts an extension object of type `sqlalchemy.orm.PropComparator`: + + {python} + from sqlalchemy.orm import PropComparator + from sqlalchemy import sql + + class PointComparator(PropComparator): + def __gt__(self, other): + """define the 'greater than' operation""" + + return sql.and_(*[a>b for a, b in + zip(self.prop.columns, + other.__composite_values__())]) + + mapper(Vertex, vertices, properties={ + 'start':composite(Point, vertices.c.x1, vertices.c.y1, comparator=PointComparator), + 'end':composite(Point, vertices.c.x2, vertices.c.y2, comparator=PointComparator) + }) + +#### Controlling Ordering {@name=orderby} + +By default, mappers will attempt to ORDER BY the "oid" column of a table, or the first primary key column, when selecting rows. This can be modified in several ways. + +The "order_by" parameter can be sent to a mapper, overriding the per-engine ordering if any. A value of None means that the mapper should not use any ordering. A non-None value, which can be a column, an `asc` or `desc` clause, or a list of any combination of these, indicates the ORDER BY clause that should be added to all select queries: + + {python} + # disable all ordering + mapper(User, users_table, order_by=None) + + # order by a column + mapper(User, users_table, order_by=users_table.c.user_id) + + # order by multiple items + mapper(User, users_table, order_by=[users_table.c.user_id, users_table.c.user_name.desc()]) + +"order_by" can also be specified with queries, overriding all other per-engine/per-mapper orderings: + + {python} + # order by a column + l = query.filter(User.user_name=='fred').order_by(User.user_id).all() + + # order by multiple criteria + l = query.filter(User.user_name=='fred').order_by([User.user_id, User.user_name.desc()]).all() + +The "order_by" property can also be specified on a `relation()` which will control the ordering of the collection: + + {python} + mapper(Address, addresses_table) + + # order address objects by address id + mapper(User, users_table, properties = { + 'addresses' : relation(Address, order_by=addresses_table.c.address_id) + }) + +Note that when using eager loaders with relations, the tables used by the eager load's join are anonymously aliased. You can only order by these columns if you specify it at the `relation()` level. 
To control ordering at the query level based on a related table, you `join()` to that relation, then order by it: + + {python} + session.query(User).join('addresses').order_by(Address.street) + +#### Mapping Class Inheritance Hierarchies {@name=inheritance} + +SQLAlchemy supports three forms of inheritance: *single table inheritance*, where several types of classes are stored in one table, *concrete table inheritance*, where each type of class is stored in its own table, and *joined table inheritance*, where the parent/child classes are stored in their own tables that are joined together in a select. Whereas support for single and joined table inheritance is strong, concrete table inheritance is a less common scenario with some particular problems so is not quite as flexible. + +When mappers are configured in an inheritance relationship, SQLAlchemy has the ability to load elements "polymorphically", meaning that a single query can return objects of multiple types. + +For the following sections, assume this class relationship: + + {python} + class Employee(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return self.__class__.__name__ + " " + self.name + + class Manager(Employee): + def __init__(self, name, manager_data): + self.name = name + self.manager_data = manager_data + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.manager_data + + class Engineer(Employee): + def __init__(self, name, engineer_info): + self.name = name + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.engineer_info + +##### Joined Table Inheritance {@name=joined} + +In joined table inheritance, each class along a particular classes' list of parents is represented by a unique table. The total set of attributes for a particular instance is represented as a join along all tables in its inheritance path. Here, we first define a table to represent the `Employee` class. This table will contain a primary key column (or columns), and a column for each attribute that's represented by `Employee`. In this case it's just `name`: + + {python} + employees = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('type', String(30), nullable=False) + ) + +The table also has a column called `type`. It is strongly advised in both single- and joined- table inheritance scenarios that the root table contains a column whose sole purpose is that of the **discriminator**; it stores a value which indicates the type of object represented within the row. The column may be of any desired datatype. While there are some "tricks" to work around the requirement that there be a discriminator column, they are more complicated to configure when one wishes to load polymorphically. + +Next we define individual tables for each of `Engineer` and `Manager`, which each contain columns that represent the attributes unique to the subclass they represent. Each table also must contain a primary key column (or columns), and in most cases a foreign key reference to the parent table. It is standard practice that the same column is used for both of these roles, and that the column is also named the same as that of the parent table. 
However this is optional in SQLAlchemy; separate columns may be used for primary key and parent-relation, the column may be named differently than that of the parent, and even a custom join condition can be specified between parent and child tables instead of using a foreign key. In joined table inheritance, the primary key of an instance is always represented by the primary key of the base table only (new in SQLAlchemy 0.4). + + {python} + engineers = Table('engineers', metadata, + Column('employee_id', Integer, ForeignKey('employees.employee_id'), primary_key=True), + Column('engineer_info', String(50)), + ) + + managers = Table('managers', metadata, + Column('employee_id', Integer, ForeignKey('employees.employee_id'), primary_key=True), + Column('manager_data', String(50)), + ) + +We then configure mappers as usual, except we use some additional arguments to indicate the inheritance relationship, the polymorphic discriminator column, and the **polymorphic identity** of each class; this is the value that will be stored in the polymorphic discriminator column. + + {python} + mapper(Employee, employees, polymorphic_on=employees.c.type, polymorphic_identity='employee') + mapper(Engineer, engineers, inherits=Employee, polymorphic_identity='engineer') + mapper(Manager, managers, inherits=Employee, polymorphic_identity='manager') + +And that's it. Querying against `Employee` will return a combination of `Employee`, `Engineer` and `Manager` objects. + +###### Polymorphic Querying Strategies {@name=querying} + +The `Query` object includes some helper functionality when dealing with joined-table inheritance mappings. These are the `with_polymorphic()` and `of_type()` methods, both of which are introduced in version 0.4.4. + +The `with_polymorphic()` method affects the specific subclass tables which the Query selects from. Normally, a query such as this: + + {python} + session.query(Employee).filter(Employee.name=='ed') + +Selects only from the `employees` table. The criterion we use in `filter()` and other methods will generate WHERE criterion against this table. What if we wanted to load `Employee` objects but also wanted to use criterion against `Engineer` ? We could just query against the `Engineer` class instead. But, if we were using criterion which filters among more than one subclass (subclasses which do not inherit directly from one to the other), we'd like to select from an outer join of all those tables. The `with_polymorphic()` method can tell `Query` which joined-table subclasses we want to select for: + + {python} + session.query(Employee).with_polymorphic(Engineer).filter(Engineer.engineer_info=='some info') + +Even without criterion, the `with_polymorphic()` method has the added advantage that instances are loaded from all of their tables in one result set. Such as, to optimize the loading of all `Employee` objects, `with_polymorphic()` accepts `'*'` as a wildcard indicating that all subclass tables should be joined: + + {python} + session.query(Employee).with_polymorphic('*').all() + +`with_polymorphic()` is an effective query-level alternative to the existing `select_table` option available on `mapper()`. + +Next is a way to join along `relation` paths while narrowing the criterion to specific subclasses. Suppose the `employees` table represents a collection of employees which are associated with a `Company` object. 
We'll add a `company_id` column to the `employees` table and a new table `companies`: + + {python} + companies = Table('companies', metadata, + Column('company_id', Integer, primary_key=True), + Column('name', String(50)) + ) + + employees = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('type', String(30), nullable=False), + Column('company_id', Integer, ForeignKey('companies.company_id')) + ) + + class Company(object): + pass + + mapper(Company, companies, properties={ + 'employees':relation(Employee) + }) + +If we wanted to join from `Company` to not just `Employee` but specifically `Engineers`, using the `join()` method or `any()` or `has()` operators will by default create a join from `companies` to `employees`, without including `engineers` or `managers` in the mix. If we wish to have criterion which is specifically against the `Engineer` class, we can tell those methods to join or subquery against the full set of tables representing the subclass using the `of_type()` opertator: + + {python} + session.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.engineer_info=='someinfo') + +A longhand notation, introduced in 0.4.3, is also available, which involves spelling out the full target selectable within a 2-tuple: + + {python} + session.query(Company).join(('employees', employees.join(engineers))).filter(Engineer.engineer_info=='someinfo') + +The second notation allows more flexibility, such as joining to any group of subclass tables: + + {python} + session.query(Company).join(('employees', employees.outerjoin(engineers).outerjoin(managers))).\ + filter(or_(Engineer.engineer_info=='someinfo', Manager.manager_data=='somedata')) + +The `any()` and `has()` operators also can be used with `of_type()` when the embedded criterion is in terms of a subclass: + + {python} + session.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.engineer_info=='someinfo')).all() + +Note that the `any()` and `has()` are both shorthand for a correlated EXISTS query. To build one by hand looks like: + + {python} + session.query(Company).filter( + exists([1], + and_(Engineer.engineer_info=='someinfo', employees.c.company_id==companies.c.company_id), + from_obj=employees.join(engineers) + ) + ).all() + +The EXISTS subquery above selects from the join of `employees` to `engineers`, and also specifies criterion which correlates the EXISTS subselect back to the parent `companies` table. + +###### Optimizing Joined Table Loads {@name=optimizing} + +When loading fresh from the database, the joined-table setup above will query from the parent table first, then for each row will issue a second query to the child table. For example, for a load of five rows with `Employee` id 3, `Manager` ids 1 and 5 and `Engineer` ids 2 and 4, will produce queries along the lines of this example: + + {python} + session.query(Employee).all() + {opensql} + SELECT employees.employee_id AS employees_employee_id, employees.name AS employees_name, employees.type AS employees_type + FROM employees ORDER BY employees.oid + [] + SELECT managers.employee_id AS managers_employee_id, managers.manager_data AS managers_manager_data + FROM managers + WHERE ? = managers.employee_id + [5] + SELECT engineers.employee_id AS engineers_employee_id, engineers.engineer_info AS engineers_engineer_info + FROM engineers + WHERE ? 
= engineers.employee_id + [2] + SELECT engineers.employee_id AS engineers_employee_id, engineers.engineer_info AS engineers_engineer_info + FROM engineers + WHERE ? = engineers.employee_id + [4] + SELECT managers.employee_id AS managers_employee_id, managers.manager_data AS managers_manager_data + FROM managers + WHERE ? = managers.employee_id + [1] + +The above query works well for a `get()` operation, since it limits the queries to only the tables directly involved in fetching a single instance. For instances which are already present in the session, the secondary table load is not needed. However, the above loading style is not efficient for loading large groups of objects, as it incurs separate queries for each parent row. + +One way to reduce the number of "secondary" loads of child rows is to "defer" them, using `polymorphic_fetch='deferred'`: + + {python} + mapper(Employee, employees, polymorphic_on=employees.c.type, \ + polymorphic_identity='employee', polymorphic_fetch='deferred') + mapper(Engineer, engineers, inherits=Employee, polymorphic_identity='engineer') + mapper(Manager, managers, inherits=Employee, polymorphic_identity='manager') + +The above configuration queries in the same manner as earlier, except the load of each "secondary" table occurs only when attributes referencing those columns are first referenced on the loaded instance. This style of loading is very efficient for cases where large selects of items occur, but a detailed "drill down" of extra inherited properties is less common. + +More commonly, an all-at-once load may be achieved by constructing a query which combines all three tables together. The easiest way to do this as of version 0.4.4 is to use the `with_polymorphic()` query method which will automatically join in the classes desired: + + {python} + query = session.query(Employee).with_polymorphic([Engineer, Manager]) + +Which produces a query like the following: + + {python} + query.all() + {opensql} + SELECT employees.employee_id AS employees_employee_id, engineers.employee_id AS engineers_employee_id, managers.employee_id AS managers_employee_id, employees.name AS employees_name, employees.type AS employees_type, engineers.engineer_info AS engineers_engineer_info, managers.manager_data AS managers_manager_data + FROM employees LEFT OUTER JOIN engineers ON employees.employee_id = engineers.employee_id LEFT OUTER JOIN managers ON employees.employee_id = managers.employee_id ORDER BY employees.oid + [] + +`with_polymorphic()` accepts a single class or mapper, a list of classes/mappers, or the string `'*'` to indicate all subclasses. It also accepts a second argument `selectable` which replaces the automatic join creation and instead selects directly from the selectable given. This can allow polymorphic loads from a variety of inheritance schemes including concrete tables, if the appropriate unions are constructed. + +Similar behavior as provided by `with_polymorphic()` can be configured at the mapper level so that any user-defined query is used by default in order to load instances. The `select_table` argument references an arbitrary selectable which the mapper will use for load operations (it has no impact on save operations). Any selectable can be used for this, such as a UNION of tables. 
For joined table inheritance, the easiest method is to use OUTER JOIN: + + {python} + join = employees.outerjoin(engineers).outerjoin(managers) + + mapper(Employee, employees, polymorphic_on=employees.c.type, \ + polymorphic_identity='employee', select_table=join) + mapper(Engineer, engineers, inherits=Employee, polymorphic_identity='engineer') + mapper(Manager, managers, inherits=Employee, polymorphic_identity='manager') + +The above mapping will produce a query similar to that of `with_polymorphic('*')` for every query of `Employee` objects. + +When `select_table` is used, `with_polymorphic()` still overrides its usage at the query level. For example, if `select_table` were configured to load from a join of multiple tables, using `with_polymorphic(Employee)` will limit the list of tables selected from to just the base table (as always, tables which don't get loaded in the first pass will be loaded on an as-needed basis). + +##### Single Table Inheritance + +Single table inheritance is where the attributes of the base class as well as all subclasses are represented within a single table. A column is present in the table for every attribute mapped to the base class and all subclasses; the columns which correspond to a single subclass are nullable. This configuration looks much like joined-table inheritance except there's only one table. In this case, a `type` column is required, as there would be no other way to discriminate between classes. The table is specified in the base mapper only; for the inheriting classes, leave their `table` parameter blank: + + {python} + employees_table = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('manager_data', String(50)), + Column('engineer_info', String(50)), + Column('type', String(20), nullable=False) + ) + + employee_mapper = mapper(Employee, employees_table, \ + polymorphic_on=employees_table.c.type, polymorphic_identity='employee') + manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer') + +Note that the mappers for the derived classes Manager and Engineer omit the specification of their associated table, as it is inherited from the employee_mapper. Omitting the table specification for derived mappers in single-table inheritance is required. + +##### Concrete Table Inheritance + +This form of inheritance maps each class to a distinct table, as below: + + {python} + employees_table = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + ) + + managers_table = Table('managers', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('manager_data', String(50)), + ) + + engineers_table = Table('engineers', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('engineer_info', String(50)), + ) + +Notice in this case there is no `type` column. If polymorphic loading is not required, there's no advantage to using `inherits` here; you just define a separate mapper for each class. + + {python} + mapper(Employee, employees_table) + mapper(Manager, managers_table) + mapper(Engineer, engineers_table) + +To load polymorphically, the `select_table` argument is currently required. In this case we must construct a UNION of all three tables. 
SQLAlchemy includes a helper function to create these called `polymorphic_union`, which will map all the different columns into a structure of selects with the same numbers and names of columns, and also generate a virtual `type` column for each subselect: + + {python} + pjoin = polymorphic_union({ + 'employee':employees_table, + 'manager':managers_table, + 'engineer':engineers_table + }, 'type', 'pjoin') + + employee_mapper = mapper(Employee, employees_table, select_table=pjoin, \ + polymorphic_on=pjoin.c.type, polymorphic_identity='employee') + manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, \ + concrete=True, polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, \ + concrete=True, polymorphic_identity='engineer') + +Upon select, the polymorphic union produces a query like this: + + {python} + session.query(Employee).all() + {opensql} + SELECT pjoin.type AS pjoin_type, pjoin.manager_data AS pjoin_manager_data, pjoin.employee_id AS pjoin_employee_id, + pjoin.name AS pjoin_name, pjoin.engineer_info AS pjoin_engineer_info + FROM ( + SELECT employees.employee_id AS employee_id, CAST(NULL AS VARCHAR(50)) AS manager_data, employees.name AS name, + CAST(NULL AS VARCHAR(50)) AS engineer_info, 'employee' AS type + FROM employees + UNION ALL + SELECT managers.employee_id AS employee_id, managers.manager_data AS manager_data, managers.name AS name, + CAST(NULL AS VARCHAR(50)) AS engineer_info, 'manager' AS type + FROM managers + UNION ALL + SELECT engineers.employee_id AS employee_id, CAST(NULL AS VARCHAR(50)) AS manager_data, engineers.name AS name, + engineers.engineer_info AS engineer_info, 'engineer' AS type + FROM engineers + ) AS pjoin ORDER BY pjoin.oid + [] + +##### Using Relations with Inheritance {@name=relations} + +Both joined-table and single table inheritance scenarios produce mappings which are usable in relation() functions; that is, it's possible to map a parent object to a child object which is polymorphic. Similarly, inheriting mappers can have `relation()`s of their own at any level, which are inherited to each child class. The only requirement for relations is that there is a table relationship between parent and child. An example is the following modification to the joined table inheritance example, which sets a bi-directional relationship between `Employee` and `Company`: + + {python} + employees_table = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('company_id', Integer, ForeignKey('companies.company_id')) + ) + + companies = Table('companies', metadata, + Column('company_id', Integer, primary_key=True), + Column('name', String(50))) + + class Company(object): + pass + + mapper(Company, companies, properties={ + 'employees': relation(Employee, backref='company') + }) + +SQLAlchemy has a lot of experience in this area; the optimized "outer join" approach can be used freely for parent and child relationships, eager loads are fully useable, query aliasing and other tricks are fully supported as well. + +In a concrete inheritance scenario, mapping `relation()`s is more difficult since the distinct classes do not share a table. 
In this case, you *can* establish a relationship from parent to child as long as a join condition can be constructed between them, i.e. each child table contains a foreign key to the parent:
+
+    {python}
+    companies = Table('companies', metadata, 
+        Column('id', Integer, primary_key=True),
+        Column('name', String(50)))
+
+    employees_table = Table('employees', metadata, 
+        Column('employee_id', Integer, primary_key=True),
+        Column('name', String(50)),
+        Column('company_id', Integer, ForeignKey('companies.id'))
+    )
+
+    managers_table = Table('managers', metadata, 
+        Column('employee_id', Integer, primary_key=True),
+        Column('name', String(50)),
+        Column('manager_data', String(50)),
+        Column('company_id', Integer, ForeignKey('companies.id'))
+    )
+
+    engineers_table = Table('engineers', metadata, 
+        Column('employee_id', Integer, primary_key=True),
+        Column('name', String(50)),
+        Column('engineer_info', String(50)),
+        Column('company_id', Integer, ForeignKey('companies.id'))
+    )
+
+    employee_mapper = mapper(Employee, employees_table, select_table=pjoin, polymorphic_on=pjoin.c.type, polymorphic_identity='employee')
+    mapper(Manager, managers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='manager')
+    mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer')
+    mapper(Company, companies, properties={
+        'employees':relation(Employee)
+    })
+
+Let's crank it up and try loading with an eager load:
+
+    {python}
+    session.query(Company).options(eagerload('employees')).all()
+    {opensql}
+    SELECT anon_1.type AS anon_1_type, anon_1.manager_data AS anon_1_manager_data, anon_1.engineer_info AS anon_1_engineer_info, 
+    anon_1.employee_id AS anon_1_employee_id, anon_1.name AS anon_1_name, anon_1.company_id AS anon_1_company_id, 
+    companies.id AS companies_id, companies.name AS companies_name 
+    FROM companies LEFT OUTER JOIN (SELECT CAST(NULL AS VARCHAR(50)) AS engineer_info, employees.employee_id AS employee_id, 
+    CAST(NULL AS VARCHAR(50)) AS manager_data, employees.name AS name, employees.company_id AS company_id, 'employee' AS type 
+    FROM employees UNION ALL SELECT CAST(NULL AS VARCHAR(50)) AS engineer_info, managers.employee_id AS employee_id, 
+    managers.manager_data AS manager_data, managers.name AS name, managers.company_id AS company_id, 'manager' AS type 
+    FROM managers UNION ALL SELECT engineers.engineer_info AS engineer_info, engineers.employee_id AS employee_id, 
+    CAST(NULL AS VARCHAR(50)) AS manager_data, engineers.name AS name, engineers.company_id AS company_id, 'engineer' AS type 
+    FROM engineers) AS anon_1 ON companies.id = anon_1.company_id ORDER BY companies.oid, anon_1.oid
+    []
+
+The big limitation with concrete table inheritance is that relation()s placed on each concrete mapper do **not** propagate to child mappers.  If you want to have the same relation()s set up on all concrete mappers, they must be configured manually on each.
+
+#### Mapping a Class against Multiple Tables {@name=joins}
+
+Mappers can be constructed against arbitrary relational units (called `Selectables`) as well as plain `Tables`.  For example, the `join` keyword from the SQL package creates a neat selectable unit comprised of multiple tables, complete with its own composite primary key, which can be passed in to a mapper as the table.
+ + {python} + # a class + class AddressUser(object): + pass + + # define a Join + j = join(users_table, addresses_table) + + # map to it - the identity of an AddressUser object will be + # based on (user_id, address_id) since those are the primary keys involved + mapper(AddressUser, j, properties={ + 'user_id':[users_table.c.user_id, addresses_table.c.user_id] + }) + +A second example: + + {python} + # many-to-many join on an association table + j = join(users_table, userkeywords, + users_table.c.user_id==userkeywords.c.user_id).join(keywords, + userkeywords.c.keyword_id==keywords.c.keyword_id) + + # a class + class KeywordUser(object): + pass + + # map to it - the identity of a KeywordUser object will be + # (user_id, keyword_id) since those are the primary keys involved + mapper(KeywordUser, j, properties={ + 'user_id':[users_table.c.user_id, userkeywords.c.user_id], + 'keyword_id':[userkeywords.c.keyword_id, keywords.c.keyword_id] + }) + +In both examples above, "composite" columns were added as properties to the mappers; these are aggregations of multiple columns into one mapper property, which instructs the mapper to keep both of those columns set at the same value. + +#### Mapping a Class against Arbitrary Selects {@name=selects} + +Similar to mapping against a join, a plain select() object can be used with a mapper as well. Below, an example select which contains two aggregate functions and a group_by is mapped to a class: + + {python} + s = select([customers, + func.count(orders).label('order_count'), + func.max(orders.price).label('highest_order')], + customers.c.customer_id==orders.c.customer_id, + group_by=[c for c in customers.c] + ).alias('somealias') + class Customer(object): + pass + + mapper(Customer, s) + +Above, the "customers" table is joined against the "orders" table to produce a full row for each customer row, the total count of related rows in the "orders" table, and the highest price in the "orders" table, grouped against the full set of columns in the "customers" table. That query is then mapped against the Customer class. New instances of Customer will contain attributes for each column in the "customers" table as well as an "order_count" and "highest_order" attribute. Updates to the Customer object will only be reflected in the "customers" table and not the "orders" table. This is because the primary key columns of the "orders" table are not represented in this mapper and therefore the table is not affected by save or delete operations. + +#### Multiple Mappers for One Class {@name=multiple} + +The first mapper created for a certain class is known as that class's "primary mapper." Other mappers can be created as well, these come in two varieties. + +* **secondary mapper** + this is a mapper that must be constructed with the keyword argument `non_primary=True`, and represents a load-only mapper. Objects that are loaded with a secondary mapper will have their save operation processed by the primary mapper. It is also invalid to add new `relation()`s to a non-primary mapper. To use this mapper with the Session, specify it to the `query` method: + + example: + + {python} + # primary mapper + mapper(User, users_table) + + # make a secondary mapper to load User against a join + othermapper = mapper(User, users_table.join(someothertable), non_primary=True) + + # select + result = session.query(othermapper).select() + + The "non primary mapper" is a rarely needed feature of SQLAlchemy; in most cases, the `Query` object can produce any kind of query that's desired. 
It's recommended that a straight `Query` be used in place of a non-primary mapper unless the mapper approach is absolutely needed.  Current use cases for the "non primary mapper" are when you want to map the class to a particular select statement or view to which additional query criteria can be added, and when the particular mapped select statement or view is to be placed in a `relation()` of a parent mapper.
+
+* **entity name mapper**
+    this is a mapper that is a fully functioning primary mapper for a class, which is distinguished from the regular primary mapper by an `entity_name` parameter.  Instances loaded with this mapper will be totally managed by this new mapper and have no connection to the original one.  Most methods on `Session` include an optional `entity_name` parameter in order to specify this condition.
+
+    example:
+
+        {python}
+        # primary mapper
+        mapper(User, users_table)
+
+        # make an entity name mapper that stores User objects in another table
+        mapper(User, alternate_users_table, entity_name='alt')
+
+        # make two User objects
+        user1 = User()
+        user2 = User()
+
+        # save one in the "users" table
+        session.save(user1)
+
+        # save the other in the "alternate_users_table"
+        session.save(user2, entity_name='alt')
+
+        session.flush()
+
+        # select from the alternate mapper
+        session.query(User, entity_name='alt').select()
+
+    Use the "entity name" mapper when different instances of the same class are persisted in completely different tables.  The "entity name" approach can also perform limited levels of horizontal partitioning.  A more comprehensive approach to horizontal partitioning is provided by the Sharding API.
+
+#### Extending Mapper {@name=extending}
+
+A mapper can have its functionality augmented or replaced at many points in its execution via the usage of the MapperExtension class.  This class is just a series of "hooks" where various functionality takes place.  An application can make its own MapperExtension objects, overriding only the methods it needs.  Methods that are not overridden return the special value `sqlalchemy.orm.EXT_CONTINUE` to allow processing to continue to the next MapperExtension or simply proceed normally if there are no more extensions.
+
+API documentation for MapperExtension:  [docstrings_sqlalchemy.orm_MapperExtension](rel:docstrings_sqlalchemy.orm_MapperExtension)
+
+To use MapperExtension, make your own subclass of it and just send it off to a mapper:
+
+    {python}
+    m = mapper(User, users_table, extension=MyExtension())
+
+Multiple extensions will be chained together and processed in order; they are specified as a list:
+
+    {python}
+    m = mapper(User, users_table, extension=[ext1, ext2, ext3])
+
+### Relation Configuration {@name=relation}
+
+The full list of options for the `relation()` function:
+
+[docstrings_sqlalchemy.orm_modfunc_relation](rel:docstrings_sqlalchemy.orm_modfunc_relation)
+
+#### Basic Relational Patterns {@name=patterns}
+
+A quick walkthrough of the basic relational patterns.
+
+##### One To Many {@name=onetomany}
+
+A one to many relationship places a foreign key in the child table referencing the parent.  SQLAlchemy creates the relationship as a collection on the parent object containing instances of the child object.
+ + {python} + parent_table = Table('parent', metadata, + Column('id', Integer, primary_key=True)) + + child_table = Table('child', metadata, + Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('parent.id'))) + + class Parent(object): + pass + + class Child(object): + pass + + mapper(Parent, parent_table, properties={ + 'children':relation(Child) + }) + + mapper(Child, child_table) + +To establish a bi-directional relationship in one-to-many, where the "reverse" side is a many to one, specify the `backref` option: + + {python} + mapper(Parent, parent_table, properties={ + 'children':relation(Child, backref='parent') + }) + + mapper(Child, child_table) + +`Child` will get a `parent` attribute with many-to-one semantics. + +##### Many To One {@name=manytoone} + +Many to one places a foreign key in the parent table referencing the child. The mapping setup is identical to one-to-many, however SQLAlchemy creates the relationship as a scalar attribute on the parent object referencing a single instance of the child object. + + {python} + parent_table = Table('parent', metadata, + Column('id', Integer, primary_key=True), + Column('child_id', Integer, ForeignKey('child.id'))) + + child_table = Table('child', metadata, + Column('id', Integer, primary_key=True), + ) + + class Parent(object): + pass + + class Child(object): + pass + + mapper(Parent, parent_table, properties={ + 'child':relation(Child) + }) + + mapper(Child, child_table) + +Backref behavior is available here as well, where `backref="parents"` will place a one-to-many collection on the `Child` class. + +##### One To One {@name=onetoone} + +One To One is essentially a bi-directional relationship with a scalar attribute on both sides. To achieve this, the `uselist=False` flag indicates the placement of a scalar attribute instead of a collection on the "many" side of the relationship. To convert one-to-many into one-to-one: + + {python} + mapper(Parent, parent_table, properties={ + 'child':relation(Child, uselist=False, backref='parent') + }) + +Or to turn many-to-one into one-to-one: + + {python} + mapper(Parent, parent_table, properties={ + 'child':relation(Child, backref=backref('parent', uselist=False)) + }) + +##### Many To Many {@name=manytomany} + +Many to Many adds an association table between two classes. The association table is indicated by the `secondary` argument to `relation()`. + + {python} + left_table = Table('left', metadata, + Column('id', Integer, primary_key=True)) + + right_table = Table('right', metadata, + Column('id', Integer, primary_key=True)) + + association_table = Table('association', metadata, + Column('left_id', Integer, ForeignKey('left.id')), + Column('right_id', Integer, ForeignKey('right.id')), + ) + + mapper(Parent, left_table, properties={ + 'children':relation(Child, secondary=association_table) + }) + + mapper(Child, right_table) + +For a bi-directional relationship, both sides of the relation contain a collection by default, which can be modified on either side via the `uselist` flag to be scalar. The `backref` keyword will automatically use the same `secondary` argument for the reverse relation: + + {python} + mapper(Parent, left_table, properties={ + 'children':relation(Child, secondary=association_table, backref='parents') + }) + +##### Association Object + +The association object pattern is a variant on many-to-many: it specifically is used when your association table contains additional columns beyond those which are foreign keys to the left and right tables. 
Instead of using the `secondary` argument, you map a new class directly to the association table. The left side of the relation references the association object via one-to-many, and the association class references the right side via many-to-one. + + {python} + left_table = Table('left', metadata, + Column('id', Integer, primary_key=True)) + + right_table = Table('right', metadata, + Column('id', Integer, primary_key=True)) + + association_table = Table('association', metadata, + Column('left_id', Integer, ForeignKey('left.id'), primary_key=True), + Column('right_id', Integer, ForeignKey('right.id'), primary_key=True), + Column('data', String(50)) + ) + + mapper(Parent, left_table, properties={ + 'children':relation(Association) + }) + + mapper(Association, association_table, properties={ + 'child':relation(Child) + }) + + mapper(Child, right_table) + +The bi-directional version adds backrefs to both relations: + + {python} + mapper(Parent, left_table, properties={ + 'children':relation(Association, backref="parent") + }) + + mapper(Association, association_table, properties={ + 'child':relation(Child, backref="parent_assocs") + }) + + mapper(Child, right_table) + +Working with the association pattern in its direct form requires that child objects are associated with an association instance before being appended to the parent; similarly, access from parent to child goes through the association object: + + {python} + # create parent, append a child via association + p = Parent() + a = Association() + a.child = Child() + p.children.append(a) + + # iterate through child objects via association, including association + # attributes + for assoc in p.children: + print assoc.data + print assoc.child + +To enhance the association object pattern such that direct access to the `Association` object is optional, SQLAlchemy provides the [plugins_associationproxy](rel:plugins_associationproxy). + +**Important Note**: it is strongly advised that the `secondary` table argument not be combined with the Association Object pattern, unless the `relation()` which contains the `secondary` argument is marked `viewonly=True`. Otherwise, SQLAlchemy may persist conflicting data to the underlying association table since it is represented by two conflicting mappings. The Association Proxy pattern should be favored in the case where access to the underlying association data is only sometimes needed. + +#### Adjacency List Relationships {@name=selfreferential} + +The **adjacency list** pattern is a common relational pattern whereby a table contains a foreign key reference to itself. This is the most common and simple way to represent hierarchical data in flat tables. The other way is the "nested sets" model, sometimes called "modified preorder". Despite what many online articles say about modified preorder, the adjacency list model is probably the most appropriate pattern for the large majority of hierarchical storage needs, for reasons of concurrency, reduced complexity, and that modified preorder has little advantage over an application which can fully load subtrees into the application space. + +SQLAlchemy commonly refers to an adjacency list relation as a **self-referential mapper**. 
In this example, we'll work with a single table called `treenodes` to represent a tree structure:
+
+    {python}
+    nodes = Table('treenodes', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('parent_id', Integer, ForeignKey('treenodes.id')),
+        Column('data', String(50)),
+        )
+
+A graph such as the following:
+
+    {diagram}
+    root --+---> child1
+           +---> child2 --+--> subchild1
+           |              +--> subchild2
+           +---> child3
+
+Would be represented with data such as:
+
+    {diagram}
+    id       parent_id     data
+    ---      -------       ----
+    1        NULL          root
+    2        1             child1
+    3        1             child2
+    4        3             subchild1
+    5        3             subchild2
+    6        1             child3
+
+SQLAlchemy's `mapper()` configuration for a self-referential one-to-many relationship is exactly like a "normal" one-to-many relationship.  When SQLAlchemy encounters the foreign key relation from `treenodes` to `treenodes`, it assumes one-to-many unless told otherwise:
+
+    {python}
+    # entity class
+    class Node(object):
+        pass
+
+    mapper(Node, nodes, properties={
+        'children':relation(Node)
+    })
+
+To create a many-to-one relationship from child to parent, an extra indicator of the "remote side" is added, which contains the `Column` object or objects indicating the remote side of the relation:
+
+    {python}
+    mapper(Node, nodes, properties={
+        'parent':relation(Node, remote_side=[nodes.c.id])
+    })
+
+And the bi-directional version combines both:
+
+    {python}
+    mapper(Node, nodes, properties={
+        'children':relation(Node, backref=backref('parent', remote_side=[nodes.c.id]))
+    })
+
+There are several examples included with SQLAlchemy illustrating self-referential strategies; these include [basic_tree.py](http://www.sqlalchemy.org/trac/browser/sqlalchemy/trunk/examples/adjacencytree/basic_tree.py) and [optimized_al.py](http://www.sqlalchemy.org/trac/browser/sqlalchemy/trunk/examples/elementtree/optimized_al.py), the latter of which illustrates how to persist and search XML documents in conjunction with [ElementTree](http://effbot.org/zone/element-index.htm).
+
+##### Self-Referential Query Strategies {@name=query}
+
+Querying self-referential structures is done in the same way as any other query in SQLAlchemy.  Below, we query for any node whose `data` attribute stores the value `child2`:
+
+    {python}
+    # get all nodes named 'child2'
+    sess.query(Node).filter(Node.data=='child2')
+
+On the subject of joins, i.e. those described in [datamapping_joins](rel:datamapping_joins), self-referential structures require the usage of aliases so that the same table can be referenced multiple times within the FROM clause of the query.  Aliasing can be done either manually using the `nodes` `Table` object as a source of aliases:
+
+    {python}
+    # get all nodes named 'subchild1' with a parent named 'child2'
+    nodealias = nodes.alias()
+    {sql}sess.query(Node).filter(Node.data=='subchild1').\
+        filter(and_(Node.parent_id==nodealias.c.id, nodealias.c.data=='child2')).all()
+    SELECT treenodes.id AS treenodes_id, treenodes.parent_id AS treenodes_parent_id, treenodes.data AS treenodes_data 
+    FROM treenodes, treenodes AS treenodes_1 
+    WHERE treenodes.data = ? AND treenodes.parent_id = treenodes_1.id AND treenodes_1.data = ?
ORDER BY treenodes.oid + ['subchild1', 'child2'] + +or automatically, using `join()` with `aliased=True`: + + {python} + # get all nodes named 'subchild1' with a parent named 'child2' + {sql}sess.query(Node).filter(Node.data=='subchild1').\ + join('parent', aliased=True).filter(Node.data=='child2').all() + SELECT treenodes.id AS treenodes_id, treenodes.parent_id AS treenodes_parent_id, treenodes.data AS treenodes_data + FROM treenodes JOIN treenodes AS treenodes_1 ON treenodes_1.id = treenodes.parent_id + WHERE treenodes.data = ? AND treenodes_1.data = ? ORDER BY treenodes.oid + ['subchild1', 'child2'] + +To add criterion to multiple points along a longer join, use `from_joinpoint=True`: + + {python} + # get all nodes named 'subchild1' with a parent named 'child2' and a grandparent 'root' + {sql}sess.query(Node).filter(Node.data=='subchild1').\ + join('parent', aliased=True).filter(Node.data=='child2').\ + join('parent', aliased=True, from_joinpoint=True).filter(Node.data=='root').all() + SELECT treenodes.id AS treenodes_id, treenodes.parent_id AS treenodes_parent_id, treenodes.data AS treenodes_data + FROM treenodes JOIN treenodes AS treenodes_1 ON treenodes_1.id = treenodes.parent_id JOIN treenodes AS treenodes_2 ON treenodes_2.id = treenodes_1.parent_id + WHERE treenodes.data = ? AND treenodes_1.data = ? AND treenodes_2.data = ? ORDER BY treenodes.oid + ['subchild1', 'child2', 'root'] + +##### Configuring Eager Loading {@name=eagerloading} + +Eager loading of relations occurs using joins or outerjoins from parent to child table during a normal query operation, such that the parent and its child collection can be populated from a single SQL statement. SQLAlchemy's eager loading uses aliased tables in all cases when joining to related items, so it is compatible with self-referential joining. However, to use eager loading with a self-referential relation, SQLAlchemy needs to be told how many levels deep it should join; otherwise the eager load will not take place. This depth setting is configured via `join_depth`: + + {python} + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=False, join_depth=2) + }) + + {sql}session.query(Node).all() + SELECT treenodes_1.id AS treenodes_1_id, treenodes_1.parent_id AS treenodes_1_parent_id, treenodes_1.data AS treenodes_1_data, treenodes_2.id AS treenodes_2_id, treenodes_2.parent_id AS treenodes_2_parent_id, treenodes_2.data AS treenodes_2_data, treenodes.id AS treenodes_id, treenodes.parent_id AS treenodes_parent_id, treenodes.data AS treenodes_data + FROM treenodes LEFT OUTER JOIN treenodes AS treenodes_2 ON treenodes.id = treenodes_2.parent_id LEFT OUTER JOIN treenodes AS treenodes_1 ON treenodes_2.id = treenodes_1.parent_id ORDER BY treenodes.oid, treenodes_2.oid, treenodes_1.oid + [] + +#### Specifying Alternate Join Conditions to relation() {@name=customjoin} + +The `relation()` function uses the foreign key relationship between the parent and child tables to formulate the **primary join condition** between parent and child; in the case of a many-to-many relationship it also formulates the **secondary join condition**. If you are working with a `Table` which has no `ForeignKey` objects on it (which can be the case when using reflected tables with MySQL), or if the join condition cannot be expressed by a simple foreign key relationship, use the `primaryjoin` and possibly `secondaryjoin` conditions to create the appropriate relationship. 
+
+In this example we create a relation `boston_addresses` which will only load the user addresses with a city of "Boston":
+
+    {python}
+    class User(object):
+        pass
+    class Address(object):
+        pass
+
+    mapper(Address, addresses_table)
+    mapper(User, users_table, properties={
+        'boston_addresses' : relation(Address, primaryjoin=
+                    and_(users_table.c.user_id==addresses_table.c.user_id, 
+                    addresses_table.c.city=='Boston'))
+    })
+
+Many to many relationships can be customized by one or both of `primaryjoin` and `secondaryjoin`, shown below with just the default many-to-many relationship explicitly set:
+
+    {python}
+    class User(object):
+        pass
+    class Keyword(object):
+        pass
+    mapper(Keyword, keywords_table)
+    mapper(User, users_table, properties={
+        'keywords':relation(Keyword, secondary=userkeywords_table,
+            primaryjoin=users_table.c.user_id==userkeywords_table.c.user_id,
+            secondaryjoin=userkeywords_table.c.keyword_id==keywords_table.c.keyword_id
+            )
+    })
+
+##### Specifying Foreign Keys {@name=fks}
+
+When using `primaryjoin` and `secondaryjoin`, SQLAlchemy also needs to be aware of which columns in the relation reference the other.  In most cases, a `Table` construct will have `ForeignKey` constructs which take care of this; however, in the case of reflected tables on a database that does not report FKs (like MySQL ISAM) or when using join conditions on columns that don't have foreign keys, the `relation()` needs to be told specifically which columns are "foreign" using the `foreign_keys` collection:
+
+    {python}
+    mapper(Address, addresses_table)
+    mapper(User, users_table, properties={
+        'addresses' : relation(Address, 
+            primaryjoin=users_table.c.user_id==addresses_table.c.user_id,
+            foreign_keys=[addresses_table.c.user_id])
+    })
+
+##### Building Query-Enabled Properties {@name=properties}
+
+Very ambitious custom join conditions may fail to be directly persistable, and in some cases may not even load correctly.  To remove the persistence part of the equation, use the flag `viewonly=True` on the `relation()`, which establishes it as a read-only attribute (data written to the collection will be ignored on flush()).  However, in extreme cases, consider using a regular Python property in conjunction with `Query` as follows:
+
+    {python}
+    class User(object):
+        def _get_addresses(self):
+            return object_session(self).query(Address).with_parent(self).filter(...).all()
+        addresses = property(_get_addresses)
+
+##### Multiple Relations against the Same Parent/Child {@name=multiplejoin}
+
+There's no restriction on how many times you can relate from parent to child.  SQLAlchemy can usually figure out what you want, particularly if the join conditions are straightforward.  Below we add a `newyork_addresses` attribute to complement the `boston_addresses` attribute:
+
+    {python}
+    mapper(User, users_table, properties={
+        'boston_addresses' : relation(Address, primaryjoin=
+                    and_(users_table.c.user_id==addresses_table.c.user_id, 
+                    addresses_table.c.city=='Boston')),
+        'newyork_addresses' : relation(Address, primaryjoin=
+                    and_(users_table.c.user_id==addresses_table.c.user_id, 
+                    addresses_table.c.city=='New York')),
+    })
+
+#### Alternate Collection Implementations {@name=collections}
+
+Mapping a one-to-many or many-to-many relationship results in a collection of values accessible through an attribute on the parent instance.
By default, this collection is a `list`:
+
+    {python}
+    mapper(Parent, parent_table, properties={
+        'children':relation(Child)
+    })
+
+    parent = Parent()
+    parent.children.append(Child())
+    print parent.children[0]
+
+Collections are not limited to lists.  Sets, mutable sequences and almost any other Python object that can act as a container can be used in place of the default list.
+
+    {python}
+    # use a set
+    mapper(Parent, parent_table, properties={
+        'children':relation(Child, collection_class=set)
+    })
+
+    parent = Parent()
+    child = Child()
+    parent.children.add(child)
+    assert child in parent.children
+
+##### Custom Collection Implementations {@name=custom}
+
+You can use your own types for collections as well.  For most cases, simply inherit from `list` or `set` and add the custom behavior.
+
+Collections in SQLAlchemy are transparently *instrumented*.  Instrumentation means that normal operations on the collection are tracked and result in changes being written to the database at flush time.  Additionally, collection operations can fire *events* which indicate some secondary operation must take place.  Examples of a secondary operation include saving the child item in the parent's `Session` (i.e. the `save-update` cascade), as well as synchronizing the state of a bi-directional relationship (i.e. a `backref`).
+
+The collections package understands the basic interface of lists, sets and dicts and will automatically apply instrumentation to those built-in types and their subclasses.  Object-derived types that implement a basic collection interface are detected and instrumented via duck-typing:
+
+    {python}
+    class ListLike(object):
+        def __init__(self):
+            self.data = []
+        def append(self, item):
+            self.data.append(item)
+        def remove(self, item):
+            self.data.remove(item)
+        def extend(self, items):
+            self.data.extend(items)
+        def __iter__(self):
+            return iter(self.data)
+        def foo(self):
+            return 'foo'
+
+`append`, `remove`, and `extend` are known list-like methods, and will be instrumented automatically.  `__iter__` is not a mutator method and won't be instrumented, and `foo` won't be either.
+
+Duck-typing (i.e. guesswork) isn't rock-solid, of course, so you can be explicit about the interface you are implementing by providing an `__emulates__` class attribute:
+
+    {python}
+    class SetLike(object):
+        __emulates__ = set
+
+        def __init__(self):
+            self.data = set()
+        def append(self, item):
+            self.data.add(item)
+        def remove(self, item):
+            self.data.remove(item)
+        def __iter__(self):
+            return iter(self.data)
+
+This class looks list-like because of `append`, but `__emulates__` forces it to be treated as set-like.  `remove` is known to be part of the set interface and will be instrumented.
+
+But this class won't work quite yet: a little glue is needed to adapt it for use by SQLAlchemy.  The ORM needs to know which methods to use to append, remove and iterate over members of the collection.  When using a type like `list` or `set`, the appropriate methods are well-known and used automatically when present.  This set-like class does not provide the expected `add` method, so we must supply an explicit mapping for the ORM via a decorator.
+
+##### Annotating Custom Collections via Decorators {@name=decorators}
+
+Decorators can be used to tag the individual methods the ORM needs to manage collections.  Use them when your class doesn't quite meet the regular interface for its container type, or you simply would like to use a different method to get the job done.
+ + {python} + from sqlalchemy.orm.collections import collection + + class SetLike(object): + __emulates__ = set + + def __init__(self): + self.data = set() + + @collection.appender + def append(self, item): + self.data.add(item) + + def remove(self, item): + self.data.remove(item) + + def __iter__(self): + return iter(self.data) + +And that's all that's needed to complete the example. SQLAlchemy will add instances via the `append` method. `remove` and `__iter__` are the default methods for sets and will be used for removing and iteration. Default methods can be changed as well: + + {python} + from sqlalchemy.orm.collections import collection + + class MyList(list): + @collection.remover + def zark(self, item): + # do something special... + + @collection.iterator + def hey_use_this_instead_for_iteration(self): + # ... + +There is no requirement to be list-, or set-like at all. Collection classes can be any shape, so long as they have the append, remove and iterate interface marked for SQLAlchemy's use. Append and remove methods will be called with a mapped entity as the single argument, and iterator methods are called with no arguments and must return an iterator. + +##### Dictionary-Based Collections {@name=dictcollections} + +A `dict` can be used as a collection, but a keying strategy is needed to map entities loaded by the ORM to key, value pairs. The [collections](rel:docstrings_sqlalchemy.orm.collections) package provides several built-in types for dictionary-based collections: + + {python} + from sqlalchemy.orm.collections import column_mapped_collection, attribute_mapped_collection, mapped_collection + + mapper(Item, items_table, properties={ + # key by column + 'notes': relation(Note, collection_class=column_mapped_collection(notes_table.c.keyword)), + # or named attribute + 'notes2': relation(Note, collection_class=attribute_mapped_collection('keyword')), + # or any callable + 'notes3': relation(Note, collection_class=mapped_collection(lambda entity: entity.a + entity.b)) + }) + + # ... + item = Item() + item.notes['color'] = Note('color', 'blue') + print item.notes['color'] + +These functions each provide a `dict` subclass with decorated `set` and `remove` methods and the keying strategy of your choice. + +The [collections.MappedCollection](rel:docstrings_sqlalchemy.orm.collections.MappedCollection) class can be used as a base class for your custom types or as a mix-in to quickly add `dict` collection support to other classes. It uses a keying function to delegate to `__setitem__` and `__delitem__`: + + {python} + from sqlalchemy.util import OrderedDict + from sqlalchemy.orm.collections import MappedCollection + + class NodeMap(OrderedDict, MappedCollection): + """Holds 'Node' objects, keyed by the 'name' attribute with insert order maintained.""" + + def __init__(self, *args, **kw): + MappedCollection.__init__(self, keyfunc=lambda node: node.name) + OrderedDict.__init__(self, *args, **kw) + +The ORM understands the `dict` interface just like lists and sets, and will automatically instrument all dict-like methods if you choose to subclass `dict` or provide dict-like collection behavior in a duck-typed class. You must decorate appender and remover methods, however- there are no compatible methods in the basic dictionary interface for SQLAlchemy to use by default. Iteration will go through `itervalues()` unless otherwise decorated. 
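+
+As a rough sketch of the duck-typed case (assuming a hypothetical `Note` class which carries a `keyword` attribute), a plain class can provide dict-like access while spelling out the appender, remover and iterator explicitly:
+
+    {python}
+    from sqlalchemy.orm.collections import collection
+
+    class KeywordKeyedNotes(object):
+        """Hypothetical dict-like collection of Note objects, keyed by Note.keyword."""
+
+        def __init__(self):
+            self.data = {}
+
+        # called by the ORM to add a loaded or appended entity
+        @collection.appender
+        def _append(self, note):
+            self.data[note.keyword] = note
+
+        # called by the ORM to remove an entity
+        @collection.remover
+        def _remove(self, note):
+            del self.data[note.keyword]
+
+        # called by the ORM to iterate over members
+        @collection.iterator
+        def _iterate(self):
+            return self.data.itervalues()
+
+        def __getitem__(self, keyword):
+            return self.data[keyword]
+
+Such a class would be plugged in via `collection_class=KeywordKeyedNotes` just like the built-in types.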
+
+##### Instrumentation and Custom Types {@name=adv_collections}
+
+Many custom types and existing library classes can be used as an entity collection type as-is without further ado.  However, it is important to note that the instrumentation process _will_ modify the type, adding decorators around methods automatically.
+
+The decorations are lightweight and no-op outside of relations, but they do add unneeded overhead when triggered elsewhere.  When using a library class as a collection, it can be good practice to use the "trivial subclass" trick to restrict the decorations to just your usage in relations.  For example:
+
+    {python}
+    class MyAwesomeList(some.great.library.AwesomeList):
+        pass
+
+    # ... relation(..., collection_class=MyAwesomeList)
+
+The ORM uses this approach for built-ins, quietly substituting a trivial subclass when a `list`, `set` or `dict` is used directly.
+
+The collections package provides additional decorators and support for authoring custom types.  See the [package documentation](rel:docstrings_sqlalchemy.orm.collections) for more information and discussion of advanced usage and Python 2.3-compatible decoration options.
+
+#### Configuring Loader Strategies: Lazy Loading, Eager Loading {@name=strategies}
+
+In the [datamapping](rel:datamapping) chapter, we introduced the concept of **Eager Loading**.  We used an `option` in conjunction with the `Query` object in order to indicate that a relation should be loaded at the same time as the parent, within a single SQL query:
+
+    {python}
+    {sql}>>> jack = session.query(User).options(eagerload('addresses')).filter_by(name='jack').all() #doctest: +NORMALIZE_WHITESPACE
+    SELECT addresses_1.id AS addresses_1_id, addresses_1.email_address AS addresses_1_email_address, 
+    addresses_1.user_id AS addresses_1_user_id, users.id AS users_id, users.name AS users_name, 
+    users.fullname AS users_fullname, users.password AS users_password 
+    FROM users LEFT OUTER JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id 
+    WHERE users.name = ? ORDER BY users.oid, addresses_1.oid
+    ['jack']
+
+By default, all relations are **lazy loading**.  The scalar or collection attribute associated with a `relation()` contains a trigger which fires the first time the attribute is accessed, which issues a SQL call at that point:
+
+    {python}
+    {sql}>>> jack.addresses
+    SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
+    FROM addresses 
+    WHERE ? = addresses.user_id ORDER BY addresses.oid
+    [5]
+    {stop}[, ]
+
+The default **loader strategy** for any `relation()` is configured by the `lazy` keyword argument, which defaults to `True`.  Below we set it to `False` so that the `children` relation is eager loading:
+
+    {python}
+    # eager load 'children' attribute
+    mapper(Parent, parent_table, properties={
+        'children':relation(Child, lazy=False)
+    })
+
+The loader strategy can be changed from lazy to eager as well as eager to lazy using the `eagerload()` and `lazyload()` query options:
+
+    {python}
+    # set children to load lazily
+    session.query(Parent).options(lazyload('children')).all()
+
+    # set children to load eagerly
+    session.query(Parent).options(eagerload('children')).all()
+
+To reference a relation that is deeper than one level, separate the names by periods:
+
+    {python}
+    session.query(Parent).options(eagerload('foo.bar.bat')).all()
+
+When using dot-separated names with `eagerload()`, the option applies **only** to the actual attribute named, and **not** its ancestors.
For example, suppose a mapping from `A` to `B` to `C`, where the relations, named `atob` and `btoc`, are both lazy-loading. A statement like the following: + + {python} + session.query(A).options(eagerload('atob.btoc')).all() + +will load only `A` objects to start. When the `atob` attribute on each `A` is accessed, the returned `B` objects will *eagerly* load their `C` objects. + +Therefore, to modify the eager load to load both `atob` as well as `btoc`, place eagerloads for both: + + {python} + session.query(A).options(eagerload('atob'), eagerload('atob.btoc')).all() + +or more simply just use `eagerload_all()`: + + {python} + session.query(A).options(eagerload_all('atob.btoc')).all() + +There are two other loader strategies available, **dynamic loading** and **no loading**; these are described in [advdatamapping_relation_largecollections](rel:advdatamapping_relation_largecollections). + +##### Combining Eager Loads with Statement/Result Set Queries + +When full statement or result-set loads are used with `Query`, SQLAlchemy does not affect the SQL query itself, and therefore has no way of tacking on its own `LEFT [OUTER] JOIN` conditions that are normally used to eager load relationships. If the query being constructed is created in such a way that it returns rows not just from a parent table (or tables) but also returns rows from child tables, the result-set mapping can be notified as to which additional properties are contained within the result set. This is done using the `contains_eager()` query option, which specifies the name of the relationship to be eagerly loaded. + + {python} + # mapping is the users->addresses mapping + mapper(User, users_table, properties={ + 'addresses':relation(Address, addresses_table) + }) + + # define a query on USERS with an outer join to ADDRESSES + statement = users_table.outerjoin(addresses_table).select(use_labels=True) + + # construct a Query object which expects the "addresses" results + query = session.query(User).options(contains_eager('addresses')) + + # get results normally + r = query.from_statement(statement) + +If the "eager" portion of the statement is "aliased", the `alias` keyword argument to `contains_eager()` may be used to indicate it. This is a string alias name or reference to an actual `Alias` object: + + {python} + # use an alias of the addresses table + adalias = addresses_table.alias('adalias') + + # define a query on USERS with an outer join to adalias + statement = users_table.outerjoin(adalias).select(use_labels=True) + + # construct a Query object which expects the "addresses" results + query = session.query(User).options(contains_eager('addresses', alias=adalias)) + + # get results normally + {sql}r = query.from_statement(statement).all() + SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, adalias.address_id AS adalias_address_id, + adalias.user_id AS adalias_user_id, adalias.email_address AS adalias_email_address, (...other columns...) 
+    FROM users LEFT OUTER JOIN email_addresses AS adalias ON users.user_id = adalias.user_id
+
+In the case that the main table itself is also aliased, the `contains_alias()` option can be used:
+
+    {python}
+    # define an aliased UNION called 'ulist'
+    statement = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist')
+
+    # add on an eager load of "addresses"
+    statement = statement.outerjoin(addresses).select(use_labels=True)
+
+    # create query, indicating "ulist" is an alias for the main table, "addresses" property should
+    # be eager loaded
+    query = create_session().query(User).options(contains_alias('ulist'), contains_eager('addresses'))
+
+    # results
+    r = query.from_statement(statement)
+
+#### Working with Large Collections {@name=largecollections}
+
+The default behavior of `relation()` is to fully load in the collection of items, according to the loading strategy of the relation.  Additionally, the Session by default only knows how to delete objects which are actually present within the session.  When a parent instance is marked for deletion and flushed, the Session loads its full list of child items in so that they may either be deleted as well, or have their foreign key value set to null; this is to avoid constraint violations.  For large collections of child items, there are several strategies to bypass full loading of child items both at load time and at deletion time.
+
+##### Dynamic Relation Loaders {@name=dynamic}
+
+The most useful by far is the `dynamic_loader()` relation.  This is a variant of `relation()` which returns a `Query` object in place of a collection when accessed.  `filter()` criteria may be applied as well as limits and offsets, either explicitly or via array slices:
+
+    {python}
+    mapper(User, users_table, properties={
+        'posts':dynamic_loader(Post)
+    })
+
+    jack = session.query(User).get(id)
+
+    # filter Jack's blog posts
+    posts = jack.posts.filter(Post.c.headline=='this is a post')
+
+    # apply array slices
+    posts = jack.posts[5:20]
+
+The dynamic relation supports limited write operations, via the `append()` and `remove()` methods.  Since the read side of the dynamic relation always queries the database, changes to the underlying collection will not be visible until the data has been flushed:
+
+    {python}
+    oldpost = jack.posts.filter(Post.c.headline=='old post').one()
+    jack.posts.remove(oldpost)
+
+    jack.posts.append(Post('new post'))
+
+To place a dynamic relation on a backref, use `lazy='dynamic'`:
+
+    {python}
+    mapper(Post, posts_table, properties={
+        'user':relation(User, backref=backref('posts', lazy='dynamic'))
+    })
+
+Note that eager/lazy loading options cannot be used in conjunction with dynamic relations at this time.
+
+##### Setting Noload {@name=noload}
+
+The opposite of the dynamic relation is simply "noload", specified using `lazy=None`:
+
+    {python}
+    mapper(MyClass, table, properties={
+        'children':relation(MyOtherClass, lazy=None)
+    })
+
+Above, the `children` collection is fully writeable, and changes to it will be persisted to the database as well as locally available for reading at the time they are added.  However, when instances of `MyClass` are freshly loaded from the database, the `children` collection stays empty.
+
+##### Using Passive Deletes {@name=passivedelete}
+
+Use `passive_deletes=True` to disable child object loading on a DELETE operation, in conjunction with "ON DELETE (CASCADE|SET NULL)" on your database to automatically cascade deletes to child objects.
Note that "ON DELETE" is not supported on SQLite, and requires `InnoDB` tables when using MySQL:
+
+    {python}
+    mytable = Table('mytable', meta,
+        Column('id', Integer, primary_key=True),
+        )
+
+    myothertable = Table('myothertable', meta,
+        Column('id', Integer, primary_key=True),
+        Column('parent_id', Integer),
+        ForeignKeyConstraint(['parent_id'],['mytable.id'], ondelete="CASCADE"),
+        )
+
+    mapper(MyOtherClass, myothertable)
+
+    mapper(MyClass, mytable, properties={
+        'children':relation(MyOtherClass, cascade="all, delete-orphan", passive_deletes=True)
+    })
+
+When `passive_deletes` is applied, the `children` relation will not be loaded into memory when an instance of `MyClass` is marked for deletion.  The `cascade="all, delete-orphan"` *will* take effect for instances of `MyOtherClass` which are currently present in the session; however for instances of `MyOtherClass` which are not loaded, SQLAlchemy assumes that "ON DELETE CASCADE" rules will ensure that those rows are deleted by the database and that no foreign key violation will occur.
+
+#### Mutable Primary Keys / Update Cascades {@name=mutablepks}
+
+As of SQLAlchemy 0.4.2, the primary key attributes of an instance can be changed freely, and will be persisted upon flush.  When the primary key of an entity changes, related items which reference the primary key must be updated as well.  For databases which enforce referential integrity, it's required to use the database's ON UPDATE CASCADE functionality in order to propagate primary key changes.  For those which don't, the `passive_updates` flag can be set to `False` which instructs SQLAlchemy to issue UPDATE statements individually.  The `passive_updates` flag can also be `False` in conjunction with ON UPDATE CASCADE functionality, although in that case it issues UPDATE statements unnecessarily.
+
+A typical mutable primary key setup might look like:
+
+    {python}
+    users = Table('users', metadata,
+        Column('username', String(50), primary_key=True),
+        Column('fullname', String(100)))
+
+    addresses = Table('addresses', metadata,
+        Column('email', String(50), primary_key=True),
+        Column('username', String(50), ForeignKey('users.username', onupdate="cascade")))
+
+    class User(object):
+        pass
+    class Address(object):
+        pass
+
+    mapper(User, users, properties={
+        'addresses':relation(Address, passive_updates=False)
+    })
+    mapper(Address, addresses)
+
+`passive_updates` is set to `True` by default.  Foreign key references to non-primary key columns are supported as well.
+
diff --git a/doc/build/content/metadata.txt b/doc/build/content/metadata.txt
index ff854bd39c..151a7ce5ed 100644
--- a/doc/build/content/metadata.txt
+++ b/doc/build/content/metadata.txt
@@ -121,7 +121,7 @@ And `Table` provides an interface to the table's properties as well as that of i
 
A `MetaData` object can be associated with an `Engine` or an individual `Connection`; this process is called **binding**.  The term used to describe "an engine or a connection" is often referred to as a **connectable**.  Binding allows the `MetaData` and the elements which it contains to perform operations against the database directly, using the connection resources to which it's bound.  Common operations which are made more convenient through binding include being able to generate SQL constructs which know how to execute themselves, creating `Table` objects which query the database for their column and constraint information, and issuing CREATE or DROP statements.
-To bind `MetaData` to an `Engine`, use the `connect()` method: +To bind `MetaData` to an `Engine`, use the `bind` attribute: {python} engine = create_engine('sqlite://', **kwargs) @@ -156,7 +156,7 @@ Note that the feature of binding engines is **completely optional**. All of the #### Reflecting Tables -A `Table` object can be created without specifying any of its contained attributes, using the argument `autoload=True` in conjunction with the table's name and possibly its schema (if not the databases "default" schema). This will issue the appropriate queries to the database in order to locate all properties of the table required for SQLAlchemy to use it effectively, including its column names and datatypes, foreign and primary key constraints, and in some cases its default-value generating attributes. To use `autoload=True`, the table's `MetaData` object need be bound to an `Engine` or `Connection`, or alternatively the `autoload_with=` argument can be passed. Below we illustrate autoloading a table and then iterating through the names of its columns: +A `Table` object can be created without specifying any of its contained attributes, using the argument `autoload=True` in conjunction with the table's name and possibly its schema (if not the databases "default" schema). (You can also specify a list or set of column names to autoload as the kwarg include_columns, if you only want to load a subset of the columns in the actual database.) This will issue the appropriate queries to the database in order to locate all properties of the table required for SQLAlchemy to use it effectively, including its column names and datatypes, foreign and primary key constraints, and in some cases its default-value generating attributes. To use `autoload=True`, the table's `MetaData` object need be bound to an `Engine` or `Connection`, or alternatively the `autoload_with=` argument can be passed. Below we illustrate autoloading a table and then iterating through the names of its columns: {python} >>> messages = Table('messages', meta, autoload=True) @@ -170,12 +170,12 @@ Note that if a reflected table has a foreign key referencing another table, the >>> 'shopping_carts' in meta.tables: True -To get direct access to 'shopping_carts', simply instantiate it via the `Table` constructor. `Table` uses a special contructor that will return the already created `Table` instance if its already present: +To get direct access to 'shopping_carts', simply instantiate it via the `Table` constructor. `Table` uses a special contructor that will return the already created `Table` instance if it's already present: {python} shopping_carts = Table('shopping_carts', meta) -Of course, its a good idea to use `autoload=True` with the above table regardless. This is so that if it hadn't been loaded already, the operation will load the table. The autoload operation only occurs for the table if it hasn't already been loaded; once loaded, new calls to `Table` will not re-issue any reflection queries. +Of course, it's a good idea to use `autoload=True` with the above table regardless. This is so that if it hadn't been loaded already, the operation will load the table. The autoload operation only occurs for the table if it hasn't already been loaded; once loaded, new calls to `Table` will not re-issue any reflection queries. 
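+
+A brief sketch of the point above, using the `shopping_carts` table from this example: calling the `Table` constructor a second time hands back the very same object, and no further reflection queries are issued:
+
+    {python}
+    carts_again = Table('shopping_carts', meta, autoload=True)
+    assert carts_again is shopping_carts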
##### Overriding Reflected Columns {@name=overriding} @@ -295,16 +295,16 @@ Entire groups of Tables can be created and dropped directly from the `MetaData` pref_value VARCHAR(100) ) -### Column Defaults and OnUpdates {@name=defaults} +### Column Insert/Update Defaults {@name=defaults} -SQLAlchemy includes flexible constructs in which to create default values for columns upon the insertion of rows, as well as upon update. These defaults can take several forms: a constant, a Python callable to be pre-executed before the SQL is executed, a SQL expression or function to be pre-executed before the SQL is executed, a pre-executed Sequence (for databases that support sequences), or a "passive" default, which is a default function triggered by the database itself upon insert, the value of which can then be post-fetched by the engine, provided the row provides a primary key in which to call upon. +SQLAlchemy includes several constructs which provide default values provided during INSERT and UPDATE statements. The defaults may be provided as Python constants, Python functions, or SQL expressions, and the SQL expressions themselves may be "pre-executed", executed inline within the insert/update statement itself, or can be created as a SQL level "default" placed on the table definition itself. A "default" value by definition is only invoked if no explicit value is passed into the INSERT or UPDATE statement. -#### Pre-Executed Insert Defaults {@name=oninsert} +#### Pre-Executed Python Functions {@name=preexecute_functions} -A basic default is most easily specified by the "default" keyword argument to Column. This defines a value, function, or SQL expression that will be pre-executed to produce the new value, before the row is inserted: +The "default" keyword argument on Column can reference a Python value or callable which is invoked at the time of an insert: {python} - # a function to create primary key ids + # a function which counts upwards i = 0 def mydefault(): global i @@ -318,8 +318,22 @@ A basic default is most easily specified by the "default" keyword argument to Co # a scalar default Column('key', String(10), default="default") ) - -The "default" keyword can also take SQL expressions, including select statements or direct function calls: + +Similarly, the "onupdate" keyword does the same thing for update statements: + + {python} + import datetime + + t = Table("mytable", meta, + Column('id', Integer, primary_key=True), + + # define 'last_updated' to be populated with datetime.now() + Column('last_updated', DateTime, onupdate=datetime.now), + ) + +#### Pre-executed and Inline SQL Expressions {@name=sqlexpression} + +The "default" and "onupdate" keywords may also be passed SQL expressions, including select statements or direct function calls: {python} t = Table("mytable", meta, @@ -330,38 +344,30 @@ The "default" keyword can also take SQL expressions, including select statements # define 'key' to pull its default from the 'keyvalues' table Column('key', String(20), default=keyvalues.select(keyvalues.c.type='type1', limit=1)) + + # define 'last_modified' to use the current_timestamp SQL function on update + Column('last_modified', DateTime, onupdate=func.current_timestamp()) ) - -The "default" keyword argument is shorthand for using a ColumnDefault object in a column definition. 
This syntax is optional, but is required for other types of defaults, futher described below: - {python} - Column('mycolumn', String(30), ColumnDefault(func.get_data())) +The above SQL functions are usually executed "inline" with the INSERT or UPDATE statement being executed. In some cases, the function is "pre-executed" and its result pre-fetched explicitly. This happens under the following circumstances: -#### Pre-Executed OnUpdate Defaults {@name=onupdate} +* the column is a primary key column -Similar to an on-insert default is an on-update default, which is most easily specified by the "onupdate" keyword to Column, which also can be a constant, plain Python function or SQL expression: +* the database dialect does not support a usable `cursor.lastrowid` accessor (or equivalent); this currently includes Postgres, Oracle, and Firebird. - {python} - t = Table("mytable", meta, - Column('id', Integer, primary_key=True), - - # define 'last_updated' to be populated with current_timestamp (the ANSI-SQL version of now()) - Column('last_updated', DateTime, onupdate=func.current_timestamp()), - ) - +* the statement is a single execution, i.e. only supplies one set of parameters and doesn't use "executemany" behavior -To use an explicit ColumnDefault object to specify an on-update, use the "for_update" keyword argument: +* the `inline=True` flag is not set on the `Insert()` or `Update()` construct. - {python} - Column('mycolumn', String(30), ColumnDefault(func.get_data(), for_update=True)) - -#### Inline Default Execution: PassiveDefault {@name=passive} +For a statement execution which is not an executemany, the returned `ResultProxy` will contain a collection accessible via `result.postfetch_cols()` which contains a list of all `Column` objects which had an inline-executed default. Similarly, all parameters which were bound to the statement, including all Python and SQL expressions which were pre-executed, are present in the `last_inserted_params()` or `last_updated_params()` collections on `ResultProxy`. The `last_inserted_ids()` collection contains a list of primary key values for the row inserted. -A PassiveDefault indicates an column default that is executed upon INSERT by the database. This construct is used to specify a SQL function that will be specified as "DEFAULT" when creating tables. +#### DDL-Level Defaults {@name=passive} + +A variant on a SQL expression default is the `PassiveDefault`, which gets placed in the CREATE TABLE statement during a `create()` operation: {python} t = Table('test', meta, - Column('mycolumn', DateTime, PassiveDefault("sysdate")) + Column('mycolumn', DateTime, PassiveDefault(text("sysdate"))) ) A create call for the above table will produce: @@ -371,31 +377,7 @@ A create call for the above table will produce: mycolumn datetime default sysdate ) -PassiveDefault also sends a message to the `Engine` that data is available after an insert. The object-relational mapper system uses this information to post-fetch rows after the insert, so that instances can be refreshed with the new data. Below is a simplified version: - - {python} - # table with passive defaults - mytable = Table('mytable', engine, - Column('my_id', Integer, primary_key=True), - - # an on-insert database-side default - Column('data1', Integer, PassiveDefault("d1_func()")), - ) - # insert a row - r = mytable.insert().execute(name='fred') - - # check the result: were there defaults fired off on that row ? - if r.lastrow_has_defaults(): - # postfetch the row based on primary key. 
- # this only works for a table with primary key columns defined - primary_key = r.last_inserted_ids() - row = table.select(table.c.id == primary_key[0]) - -When Tables are reflected from the database using `autoload=True`, any DEFAULT values set on the columns will be reflected in the Table object as PassiveDefault instances. - -##### The Catch: Postgres Primary Key Defaults always Pre-Execute {@name=postgres} - -Current Postgres support does not rely upon OID's to determine the identity of a row. This is because the usage of OIDs has been deprecated with Postgres and they are disabled by default for table creates as of PG version 8. Pyscopg2's "cursor.lastrowid" function only returns OIDs. Therefore, when inserting a new row which has passive defaults set on the primary key columns, the default function is still pre-executed since SQLAlchemy would otherwise have no way of retrieving the row just inserted. +The behavior of `PassiveDefault` is similar to that of a regular SQL default; if it's placed on a primary key column for a database which doesn't have a way to "postfetch" the ID, and the statement is not "inlined", the SQL expression is pre-executed; otherwise, SQLAlchemy lets the default fire off on the database side normally. #### Defining Sequences {@name=sequences} @@ -408,11 +390,17 @@ A table with a sequence looks like: Column("createdate", DateTime()) ) -The Sequence is used with Postgres or Oracle to indicate the name of a database sequence that will be used to create default values for a column. When a table with a Sequence on a column is created in the database by SQLAlchemy, the database sequence object is also created. Similarly, the database sequence is dropped when the table is dropped. Sequences are typically used with primary key columns. When using Postgres, if an integer primary key column defines no explicit Sequence or other default method, SQLAlchemy will create the column with the SERIAL keyword, and will pre-execute a sequence named "tablename_columnname_seq" in order to retrieve new primary key values, if they were not otherwise explicitly stated. Oracle, which has no "auto-increment" keyword, requires that a Sequence be specified for a table if automatic primary key generation is desired. +The `Sequence` object works a lot like the `default` keyword on `Column`, except that it only takes effect on a database which supports sequences. When used with a database that does not support sequences, the `Sequence` object has no effect; therefore it's safe to place on a table which is used against multiple database backends. The same rules for pre- and inline execution apply. -A Sequence object can be defined on a Table that is then also used with a non-sequence-supporting database. In that case, the Sequence object is simply ignored. Note that a Sequence object is **entirely optional for all databases except Oracle**, as other databases offer options for auto-creating primary key values, such as AUTOINCREMENT, SERIAL, etc. SQLAlchemy will use these default methods for creating primary key values if no Sequence is present on the table metadata. +When the `Sequence` is associated with a table, CREATE and DROP statements issued for that table will also issue CREATE/DROP for the sequence object as well, thus "bundling" the sequence object with its parent table. 
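As an illustration of the "bundling" behavior described above, here is a minimal sketch which is not taken from the documentation itself; the table and sequence names are hypothetical, and an `engine` created via `create_engine()` against a sequence-supporting backend such as Postgres or Oracle is assumed:

    {python}
    from sqlalchemy import MetaData, Table, Column, Integer, String, Sequence

    meta = MetaData()
    account_table = Table('accounts', meta,
        # the Sequence is given as a positional argument to Column
        Column('id', Integer, Sequence('account_id_seq'), primary_key=True),
        Column('name', String(50))
    )

    # emits CREATE SEQUENCE account_id_seq as well as CREATE TABLE accounts
    meta.create_all(engine)

    # emits DROP TABLE accounts as well as DROP SEQUENCE account_id_seq
    meta.drop_all(engine)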
-A sequence can also be specified with `optional=True` which indicates the Sequence should only be used on a database that requires an explicit sequence, and not those that supply some other method of providing integer values. At the moment, it essentially means "use this sequence only with Oracle and not Postgres". +The flag `optional=True` on `Sequence` will produce a sequence that is only used on databases which have no "autoincrementing" capability. For example, Postgres supports primary key generation using the SERIAL keyword, whereas Oracle has no such capability. Therefore, a `Sequence` placed on a primary key column with `optional=True` will only be used with an Oracle backend but not Postgres. + +A sequence can also be executed standalone, using an `Engine` or `Connection`, returning its next value in a database-independent fashion: + + {python} + seq = Sequence('some_sequence') + nextid = connection.execute(seq) ### Defining Constraints and Indexes {@name=constraints} @@ -438,20 +426,20 @@ Unique constraints can be created anonymously on a single column using the `uniq Check constraints can be named or unnamed and can be created at the Column or Table level, using the `CheckConstraint` construct. The text of the check constraint is passed directly through to the database, so there is limited "database independent" behavior. Column level check constraints generally should only refer to the column to which they are placed, while table level constraints can refer to any columns in the table. -Note that some databases do not actively support check constraints such as MySQL and sqlite. +Note that some databases do not actively support check constraints such as MySQL and SQLite. {python} meta = MetaData() mytable = Table('mytable', meta, # per-column CHECK constraint - Column('col1', Integer, CheckConstraint('col1>5')), + Column('col1', Integer, CheckConstraint('col1>5')), Column('col2', Integer), Column('col3', Integer), # table level CHECK constraint. 'name' is optional. - CheckConstraint('col2 > col3 + 5', name='check1') + CheckConstraint('col2 > col3 + 5', name='check1') ) #### Indexes diff --git a/doc/build/content/ormtutorial.txt b/doc/build/content/ormtutorial.txt new file mode 100644 index 0000000000..85ffa0ff93 --- /dev/null +++ b/doc/build/content/ormtutorial.txt @@ -0,0 +1,1143 @@ +[alpha_api]: javascript:alphaApi() +[alpha_implementation]: javascript:alphaImplementation() + +Object Relational Tutorial {@name=datamapping} +============ + +In this tutorial we will cover a basic SQLAlchemy object-relational mapping scenario, where we store and retrieve Python objects from a database representation. The database schema will begin with one table, and will later develop into several. The tutorial is in doctest format, meaning each `>>>` line represents something you can type at a Python command prompt, and the following text represents the expected return value. The tutorial has no prerequisites. + +## Version Check + +A quick check to verify that we are on at least **version 0.4** of SQLAlchemy: + + {python} + >>> import sqlalchemy + >>> sqlalchemy.__version__ # doctest:+SKIP + 0.4.0 + +## Connecting + +For this tutorial we will use an in-memory-only SQLite database. This is an easy way to test things without needing to have an actual database defined anywhere. 
To connect we use `create_engine()`: + + {python} + >>> from sqlalchemy import create_engine + >>> engine = create_engine('sqlite:///:memory:', echo=True) + +The `echo` flag is a shortcut to setting up SQLAlchemy logging, which is accomplished via Python's standard `logging` module. With it enabled, we'll see all the generated SQL produced. If you are working through this tutorial and want less output generated, set it to `False`. This tutorial will format the SQL behind a popup window so it doesn't get in our way; just click the "SQL" links to see whats being generated. + +## Define and Create a Table {@name=tables} + +Next we want to tell SQLAlchemy about our tables. We will start with just a single table called `users`, which will store records for the end-users using our application (lets assume it's a website). We define our tables all within a catalog called `MetaData`, using the `Table` construct, which resembles regular SQL CREATE TABLE syntax: + + {python} + >>> from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey + >>> metadata = MetaData() + >>> users_table = Table('users', metadata, + ... Column('id', Integer, primary_key=True), + ... Column('name', String(40)), + ... Column('fullname', String(100)), + ... Column('password', String(15)) + ... ) + +All about how to define `Table` objects, as well as how to create them from an existing database automatically, is described in [metadata](rel:metadata). + +Next, to tell the `MetaData` we'd actually like to create our `users_table` for real inside the SQLite database, we use `create_all()`, passing it the `engine` instance which points to our database. This will check for the presence of a table first before creating, so it's safe to call multiple times: + + {python} + {sql}>>> metadata.create_all(engine) # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE + PRAGMA table_info("users") + {} + CREATE TABLE users ( + id INTEGER NOT NULL, + name VARCHAR(40), + fullname VARCHAR(100), + password VARCHAR(15), + PRIMARY KEY (id) + ) + {} + COMMIT + +So now our database is created, our initial schema is present, and our SQLAlchemy application knows all about the tables and columns in the database; this information is to be re-used by the Object Relational Mapper, as we'll see now. + +## Define a Python Class to be Mapped {@name=mapping} + +So lets create a rudimentary `User` object to be mapped in the database. This object will for starters have three attributes, `name`, `fullname` and `password`. It only need subclass Python's built-in `object` class (i.e. it's a new style class). We will give it a constructor so that it may conveniently be instantiated with its attributes at once, as well as a `__repr__` method so that we can get a nice string representation of it: + + {python} + >>> class User(object): + ... def __init__(self, name, fullname, password): + ... self.name = name + ... self.fullname = fullname + ... self.password = password + ... + ... def __repr__(self): + ... return "" % (self.name, self.fullname, self.password) + +## Setting up the Mapping + +With our `users_table` and `User` class, we now want to map the two together. That's where the SQLAlchemy ORM package comes in. We'll use the `mapper` function to create a **mapping** between `users_table` and `User`: + + {python} + >>> from sqlalchemy.orm import mapper + >>> mapper(User, users_table) # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE + + +The `mapper()` function creates a new `Mapper` object and stores it away for future reference. 
It also **instruments** the attributes on our `User` class, corresponding to the `users_table` table. The `id`, `name`, `fullname`, and `password` columns in our `users_table` are now instrumented upon our `User` class, meaning it will keep track of all changes to these attributes, and can save and load their values to/from the database. Lets create our first user, 'Ed Jones', and ensure that the object has all three of these attributes: + + {python} + >>> ed_user = User('ed', 'Ed Jones', 'edspassword') + >>> ed_user.name + 'ed' + >>> ed_user.password + 'edspassword' + >>> str(ed_user.id) + 'None' + +What was that last `id` attribute? That was placed there by the `Mapper`, to track the value of the `id` column in the `users_table`. Since our `User` doesn't exist in the database, its id is `None`. When we save the object, it will get populated automatically with its new id. + +## Too Verbose ? There are alternatives + +The full set of steps to map a class, which are to define a `Table`, define a class, and then define a `mapper()`, are fairly verbose and for simple cases may appear overly disjoint. Most popular object relational products use the so-called "active record" approach, where the table definition and its class mapping are all defined at once. With SQLAlchemy, there are two excellent alternatives to its usual configuration which provide this approach: + + * [Elixir](http://elixir.ematia.de/) is a "sister" product to SQLAlchemy, which is a full "declarative" layer built on top of SQLAlchemy. It has existed almost as long as SA itself and defines a rich featureset on top of SA's normal configuration, adding many new capabilities such as plugins, automatic generation of table and column names based on configurations, and an intuitive system of defining relations. + * [declarative](rel:plugins_declarative) is a so-called "micro-declarative" plugin included with SQLAlchemy 0.4.4 and above. In contrast to Elixir, it maintains the use of the same configurational constructs outlined in this tutorial, except it allows the `Column`, `relation()`, and other constructs to be defined "inline" with the mapped class itself, so that explicit calls to `Table` and `mapper()` are not needed in most cases. + +With either declarative layer it's a good idea to be familiar with SQLAlchemy's "base" configurational style in any case. But now that we have our configuration started, we're ready to look at how to build sessions and query the database; this process is the same regardless of configurational style. + +## Creating a Session + +We're now ready to start talking to the database. The ORM's "handle" to the database is the `Session`. When we first set up the application, at the same level as our `create_engine()` statement, we define a second object called `Session` (or whatever you want to call it, `create_session`, etc.) which is configured by the `sessionmaker()` function. This function is configurational and need only be called once. 
+ + {python} + >>> from sqlalchemy.orm import sessionmaker + >>> Session = sessionmaker(bind=engine, autoflush=True, transactional=True) + +In the case where your application does not yet have an `Engine` when you define your module-level objects, just set it up like this: + + {python} + >>> Session = sessionmaker(autoflush=True, transactional=True) + +Later, when you create your engine with `create_engine()`, connect it to the `Session` using `configure()`: + + {python} + >>> Session.configure(bind=engine) # once engine is available + +This `Session` class will create new `Session` objects which are bound to our database and have the transactional characteristics we've configured. Whenever you need to have a conversation with the database, you instantiate a `Session`: + + {python} + >>> session = Session() + +The above `Session` is associated with our SQLite `engine`, but it hasn't opened any connections yet. When it's first used, it retrieves a connection from a pool of connections maintained by the `engine`, and holds onto it until we commit all changes and/or close the session object. Because we configured `transactional=True`, there's also a transaction in progress (one notable exception to this is MySQL, when you use its default table style of MyISAM). There's options available to modify this behavior but we'll go with this straightforward version to start. + +## Saving Objects + +So saving our `User` is as easy as issuing `save()`: + + {python} + >>> session.save(ed_user) + +But you'll notice nothing has happened yet. Well, lets pretend something did, and try to query for our user. This is done using the `query()` method on `Session`. We create a new query representing the set of all `User` objects first. Then we narrow the results by "filtering" down to the user we want; that is, the user whose `name` attribute is `"ed"`. Finally we call `first()` which tells `Query`, "we'd like the first result in this list". + + {python} + {sql}>>> session.query(User).filter_by(name='ed').first() # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE + BEGIN + INSERT INTO users (name, fullname, password) VALUES (?, ?, ?) + ['ed', 'Ed Jones', 'edspassword'] + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name = ? ORDER BY users.oid + LIMIT 1 OFFSET 0 + ['ed'] + {stop} + +And we get back our new user. If you view the generated SQL, you'll see that the `Session` issued an `INSERT` statement before querying. The `Session` stores whatever you put into it in memory, and at certain points it issues a **flush**, which issues SQL to the database to store all pending new objects and changes to existing objects. You can manually invoke the flush operation using `flush()`; however when the `Session` is configured to `autoflush`, it's usually not needed. + +OK, let's do some more operations. We'll create and save three more users: + + {python} + >>> session.save(User('wendy', 'Wendy Williams', 'foobar')) + >>> session.save(User('mary', 'Mary Contrary', 'xxg527')) + >>> session.save(User('fred', 'Fred Flinstone', 'blah')) + +Also, Ed has already decided his password isn't too secure, so lets change it: + + {python} + >>> ed_user.password = 'f8s7ccs' + +Then we'll permanently store everything thats been changed and added to the database. We do this via `commit()`: + + {python} + {sql}>>> session.commit() + UPDATE users SET password=? WHERE users.id = ? 
+ ['f8s7ccs', 1] + INSERT INTO users (name, fullname, password) VALUES (?, ?, ?) + ['wendy', 'Wendy Williams', 'foobar'] + INSERT INTO users (name, fullname, password) VALUES (?, ?, ?) + ['mary', 'Mary Contrary', 'xxg527'] + INSERT INTO users (name, fullname, password) VALUES (?, ?, ?) + ['fred', 'Fred Flinstone', 'blah'] + COMMIT + +`commit()` flushes whatever remaining changes remain to the database, and commits the transaction. The connection resources referenced by the session are now returned to the connection pool. Subsequent operations with this session will occur in a **new** transaction, which will again re-acquire connection resources when first needed. + +If we look at Ed's `id` attribute, which earlier was `None`, it now has a value: + + {python} + >>> ed_user.id + 1 + +After each `INSERT` operation, the `Session` assigns all newly generated ids and column defaults to the mapped object instance. For column defaults which are database-generated and are not part of the table's primary key, they'll be loaded when you first reference the attribute on the instance. + +One crucial thing to note about the `Session` is that each object instance is cached within the Session, based on its primary key identifier. The reason for this cache is not as much for performance as it is for maintaining an **identity map** of instances. This map guarantees that whenever you work with a particular `User` object in a session, **you always get the same instance back**. As below, reloading Ed gives us the same instance back: + + {python} + {sql}>>> ed_user is session.query(User).filter_by(name='ed').one() # doctest: +NORMALIZE_WHITESPACE + BEGIN + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name = ? ORDER BY users.oid + LIMIT 2 OFFSET 0 + ['ed'] + {stop}True + +The `get()` method, which queries based on primary key, will not issue any SQL to the database if the given key is already present: + + {python} + >>> ed_user is session.query(User).get(ed_user.id) + True + +## Querying + +A whirlwind tour through querying. + +A `Query` is created from the `Session`, relative to a particular class we wish to load. + + {python} + >>> query = session.query(User) + +Once we have a query, we can start loading objects. The Query object, when first created, represents all the instances of its main class. You can iterate through it directly: + + {python} + {sql}>>> for user in session.query(User): + ... print user.name + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users ORDER BY users.oid + [] + {stop}ed + wendy + mary + fred + +...and the SQL will be issued at the point where the query is evaluated as a list. If you apply array slices before iterating, LIMIT and OFFSET are applied to the query: + + {python} + {sql}>>> for u in session.query(User)[1:3]: #doctest: +NORMALIZE_WHITESPACE + ... print u + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users ORDER BY users.oid + LIMIT 2 OFFSET 1 + [] + {stop} + + +Narrowing the results down is accomplished either with `filter_by()`, which uses keyword arguments: + + {python} + {sql}>>> for user in session.query(User).filter_by(name='ed', fullname='Ed Jones'): + ... 
print user + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.fullname = ? AND users.name = ? ORDER BY users.oid + ['Ed Jones', 'ed'] + {stop} + +...or `filter()`, which uses SQL expression language constructs. These allow you to use regular Python operators with the class-level attributes on your mapped class: + + {python} + {sql}>>> for user in session.query(User).filter(User.name=='ed'): + ... print user + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name = ? ORDER BY users.oid + ['ed'] + {stop} + +You can also use the `Column` constructs attached to the `users_table` object to construct SQL expressions: + + {python} + {sql}>>> for user in session.query(User).filter(users_table.c.name=='ed'): + ... print user + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name = ? ORDER BY users.oid + ['ed'] + {stop} + +Most common SQL operators are available, such as `LIKE`: + + {python} + {sql}>>> session.query(User).filter(User.name.like('%ed'))[1] # doctest: +NORMALIZE_WHITESPACE + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name LIKE ? ORDER BY users.oid + LIMIT 1 OFFSET 1 + ['%ed'] + {stop} + +Note above our array index of `1` placed the appropriate LIMIT/OFFSET and returned a scalar result immediately. + +The `all()`, `one()`, and `first()` methods immediately issue SQL without using an iterative context or array index. `all()` returns a list: + + {python} + >>> query = session.query(User).filter(User.name.like('%ed')) + + {sql}>>> query.all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name LIKE ? ORDER BY users.oid + ['%ed'] + {stop}[, ] + +`first()` applies a limit of one and returns the first result as a scalar: + + {python} + {sql}>>> query.first() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name LIKE ? ORDER BY users.oid + LIMIT 1 OFFSET 0 + ['%ed'] + {stop} + +and `one()`, applies a limit of *two*, and if not exactly one row returned (no more, no less), raises an error: + + {python} + {sql}>>> try: + ... user = query.one() + ... except Exception, e: + ... print e + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name LIKE ? ORDER BY users.oid + LIMIT 2 OFFSET 0 + ['%ed'] + {stop}Multiple rows returned for one() + +All `Query` methods that don't return a result instead return a new `Query` object, with modifications applied. Therefore you can call many query methods successively to build up the criterion you want: + + {python} + {sql}>>> session.query(User).filter(User.id<2).filter_by(name='ed').\ + ... filter(User.fullname=='Ed Jones').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.id < ? AND users.name = ? AND users.fullname = ? 
ORDER BY users.oid + [2, 'ed', 'Ed Jones'] + {stop}[] + +If you need to use other conjunctions besides `AND`, all SQL conjunctions are available explicitly within expressions, such as `and_()` and `or_()`, when using `filter()`: + + {python} + >>> from sqlalchemy import and_, or_ + + {sql}>>> session.query(User).filter( + ... and_(User.id<224, or_(User.name=='ed', User.name=='wendy')) + ... ).all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.id < ? AND (users.name = ? OR users.name = ?) ORDER BY users.oid + [224, 'ed', 'wendy'] + {stop}[, ] + +You also have full ability to use literal strings to construct SQL. For a single criterion, use a string with `filter()`: + + {python} + {sql}>>> for user in session.query(User).filter("id<224").all(): + ... print user.name + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE id<224 ORDER BY users.oid + [] + {stop}ed + wendy + mary + fred + +Bind parameters can be specified with string-based SQL, using a colon. To specify the values, use the `params()` method: + + {python} + {sql}>>> session.query(User).filter("id<:value and name=:name").\ + ... params(value=224, name='fred').one() # doctest: +NORMALIZE_WHITESPACE + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE id + +Note that when we use constructed SQL expressions, bind parameters are generated for us automatically; we don't need to worry about them. + +To use an entirely string-based statement, using `from_statement()`; just ensure that the columns clause of the statement contains the column names normally used by the mapper (below illustrated using an asterisk): + + {python} + {sql}>>> session.query(User).from_statement("SELECT * FROM users where name=:name").params(name='ed').all() + SELECT * FROM users where name=? + ['ed'] + {stop}[] + +`from_statement()` can also accomodate full `select()` constructs. These are described in the [sql](rel:sql): + + {python} + >>> from sqlalchemy import select, func + + {sql}>>> session.query(User).from_statement( + ... select( + ... [users_table], + ... select([func.max(users_table.c.name)]).label('maxuser')==users_table.c.name) + ... ).all() # doctest: +NORMALIZE_WHITESPACE + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE (SELECT max(users.name) AS max_1 + FROM users) = users.name + [] + {stop}[] + +There's also a way to combine scalar results with objects, using `add_column()`. This is often used for functions and aggregates. When `add_column()` (or its cousin `add_entity()`, described later) is used, tuples are returned: + + {python} + {sql}>>> for r in session.query(User).\ + ... add_column(select([func.max(users_table.c.name)]).label('maxuser')): + ... print r # doctest: +NORMALIZE_WHITESPACE + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password, (SELECT max(users.name) AS max_1 + FROM users) AS maxuser + FROM users ORDER BY users.oid + [] + {stop}(, u'wendy') + (, u'wendy') + (, u'wendy') + (, u'wendy') + +## Building a One-to-Many Relation {@name=onetomany} + +We've spent a lot of time dealing with just one class, and one table. 
Let's now look at how SQLAlchemy deals with two tables, which have a relationship to each other. Let's say that the users in our system also can store any number of email addresses associated with their username. This implies a basic one to many association from the `users_table` to a new table which stores email addresses, which we will call `addresses`. We will also create a relationship between this new table to the users table, using a `ForeignKey`: + + {python} + >>> from sqlalchemy import ForeignKey + + >>> addresses_table = Table('addresses', metadata, + ... Column('id', Integer, primary_key=True), + ... Column('email_address', String(100), nullable=False), + ... Column('user_id', Integer, ForeignKey('users.id'))) + +Another call to `create_all()` will skip over our `users` table and build just the new `addresses` table: + + {python} + {sql}>>> metadata.create_all(engine) # doctest: +NORMALIZE_WHITESPACE + PRAGMA table_info("users") + {} + PRAGMA table_info("addresses") + {} + CREATE TABLE addresses ( + id INTEGER NOT NULL, + email_address VARCHAR(100) NOT NULL, + user_id INTEGER, + PRIMARY KEY (id), + FOREIGN KEY(user_id) REFERENCES users (id) + ) + {} + COMMIT + +For our ORM setup, we're going to start all over again. We will first close out our `Session` and clear all `Mapper` objects: + + {python} + >>> from sqlalchemy.orm import clear_mappers + >>> session.close() + >>> clear_mappers() + +Our `User` class, still around, reverts to being just a plain old class. Lets create an `Address` class to represent a user's email address: + + {python} + >>> class Address(object): + ... def __init__(self, email_address): + ... self.email_address = email_address + ... + ... def __repr__(self): + ... return "" % self.email_address + +Now comes the fun part. We define a mapper for each class, and associate them using a function called `relation()`. We can define each mapper in any order we want: + + {python} + >>> from sqlalchemy.orm import relation + + >>> mapper(User, users_table, properties={ # doctest: +ELLIPSIS + ... 'addresses':relation(Address, backref='user') + ... }) + + + >>> mapper(Address, addresses_table) # doctest: +ELLIPSIS + + +Above, the new thing we see is that `User` has defined a relation named `addresses`, which will reference a list of `Address` objects. How does it know it's a list ? SQLAlchemy figures it out for you, based on the foreign key relationship between `users_table` and `addresses_table`. + +## Working with Related Objects and Backreferences {@name=relation_backref} + +Now when we create a `User`, it automatically has this collection present: + + {python} + >>> jack = User('jack', 'Jack Bean', 'gjffdd') + >>> jack.addresses + [] + +We are free to add `Address` objects, and the `session` will take care of everything for us. + + {python} + >>> jack.addresses.append(Address(email_address='jack@google.com')) + >>> jack.addresses.append(Address(email_address='j25@yahoo.com')) + +Before we save into the `Session`, lets examine one other thing that's happened here. The `addresses` collection is present on our `User` because we added a `relation()` with that name. But also within the `relation()` function is the keyword `backref`. This keyword indicates that we wish to make a **bi-directional relationship**. What this basically means is that not only did we generate a one-to-many relationship called `addresses` on the `User` class, we also generated a **many-to-one** relationship on the `Address` class. 
This relationship is self-updating, without any data being flushed to the database, as we can see on one of Jack's addresses: + + {python} + >>> jack.addresses[1] + + + >>> jack.addresses[1].user + + +Let's save into the session, then close out the session and create a new one...so that we can see how `Jack` and his email addresses come back to us: + + {python} + >>> session.save(jack) + {sql}>>> session.commit() + BEGIN + INSERT INTO users (name, fullname, password) VALUES (?, ?, ?) + ['jack', 'Jack Bean', 'gjffdd'] + INSERT INTO addresses (email_address, user_id) VALUES (?, ?) + ['jack@google.com', 5] + INSERT INTO addresses (email_address, user_id) VALUES (?, ?) + ['j25@yahoo.com', 5] + COMMIT + + >>> session = Session() + +Querying for Jack, we get just Jack back. No SQL is yet issued for for Jack's addresses: + + {python} + {sql}>>> jack = session.query(User).filter_by(name='jack').one() + BEGIN + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name = ? ORDER BY users.oid + LIMIT 2 OFFSET 0 + ['jack'] + + >>> jack + + +Let's look at the `addresses` collection. Watch the SQL: + + {python} + {sql}>>> jack.addresses + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE ? = addresses.user_id ORDER BY addresses.oid + [5] + {stop}[, ] + +When we accessed the `addresses` collection, SQL was suddenly issued. This is an example of a **lazy loading relation**. + +If you want to reduce the number of queries (dramatically, in many cases), we can apply an **eager load** to the query operation. We clear out the session to ensure that a full reload occurs: + + {python} + >>> session.clear() + +Then apply an **option** to the query, indicating that we'd like `addresses` to load "eagerly". SQLAlchemy then constructs a join between the `users` and `addresses` tables: + + {python} + >>> from sqlalchemy.orm import eagerload + + {sql}>>> jack = session.query(User).options(eagerload('addresses')).filter_by(name='jack').one() #doctest: +NORMALIZE_WHITESPACE + SELECT anon_1.users_id AS anon_1_users_id, anon_1.users_name AS anon_1_users_name, + anon_1.users_fullname AS anon_1_users_fullname, anon_1.users_password AS anon_1_users_password, + addresses_1.id AS addresses_1_id, addresses_1.email_address AS addresses_1_email_address, + addresses_1.user_id AS addresses_1_user_id + FROM (SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, + users.password AS users_password, users.oid AS users_oid + FROM users + WHERE users.name = ? ORDER BY users.oid + LIMIT 2 OFFSET 0) AS anon_1 LEFT OUTER JOIN addresses AS addresses_1 + ON anon_1.users_id = addresses_1.user_id ORDER BY anon_1.oid, addresses_1.oid + ['jack'] + + >>> jack + + + >>> jack.addresses + [, ] + +If you think that query is elaborate, it is ! But SQLAlchemy is just getting started. Note that when using eager loading, *nothing* changes as far as the ultimate results returned. The "loading strategy", as it's called, is designed to be completely transparent in all cases, and is for optimization purposes only. Any query criterion you use to load objects, including ordering, limiting, other joins, etc., should return identical results regardless of the combination of lazily- and eagerly- loaded relationships present. 
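As a side note, and as a minimal sketch rather than part of the tutorial's doctest flow: the eager strategy can also be established as the relation's default at mapper configuration time using the `lazy=False` flag, assuming the mappers were being defined from scratch; per-query options are then only needed to switch the strategy back:

    {python}
    # hypothetical configuration: if the mappers were being set up from scratch,
    # 'addresses' could default to eager loading
    mapper(User, users_table, properties={
        'addresses': relation(Address, backref='user', lazy=False)
    })

    # an individual query can still revert to lazy loading for that load
    from sqlalchemy.orm import lazyload
    session.query(User).options(lazyload('addresses')).filter_by(name='jack').all()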
+ +An eagerload targeting across multiple relations can use dot separated names: + + {python} + query.options(eagerload('orders'), eagerload('orders.items'), eagerload('orders.items.keywords')) + +To roll up the above three individual `eagerload()` calls into one, use `eagerload_all()`: + + {python} + query.options(eagerload_all('orders.items.keywords')) + +## Querying with Joins {@name=joins} + +Which brings us to the next big topic. What if we want to create joins that *do* change the results ? For that, another `Query` tornado is coming.... + +One way to join two tables together is just to compose a SQL expression. Below we make one up using the `id` and `user_id` attributes on our mapped classes: + + {python} + {sql}>>> session.query(User).filter(User.id==Address.user_id).\ + ... filter(Address.email_address=='jack@google.com').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users, addresses + WHERE users.id = addresses.user_id AND addresses.email_address = ? ORDER BY users.oid + ['jack@google.com'] + {stop}[] + +Or we can make a real JOIN construct; below we use the `join()` function available on `Table` to create a `Join` object, then tell the `Query` to use it as our FROM clause: + + {python} + {sql}>>> session.query(User).select_from(users_table.join(addresses_table)).\ + ... filter(Address.email_address=='jack@google.com').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users JOIN addresses ON users.id = addresses.user_id + WHERE addresses.email_address = ? ORDER BY users.oid + ['jack@google.com'] + {stop}[] + +Note that the `join()` construct has no problem figuring out the correct join condition between `users_table` and `addresses_table`..the `ForeignKey` we constructed says it all. + +The easiest way to join is automatically, using the `join()` method on `Query`. Just give this method the path from A to B, using the name of a mapped relationship directly: + + {python} + {sql}>>> session.query(User).join('addresses').\ + ... filter(Address.email_address=='jack@google.com').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users JOIN addresses ON users.id = addresses.user_id + WHERE addresses.email_address = ? ORDER BY users.oid + ['jack@google.com'] + {stop}[] + +By "A to B", we mean a single relation name or a path of relations. In our case we only have `User->addresses->Address` configured, but if we had a setup like `A->bars->B->bats->C->widgets->D`, a join along all four entities would look like: + + {python} + session.query(Foo).join(['bars', 'bats', 'widgets']).filter(...) + +Each time `join()` is called on `Query`, the **joinpoint** of the query is moved to be that of the endpoint of the join. As above, when we joined from `users_table` to `addresses_table`, all subsequent criterion used by `filter_by()` are against the `addresses` table. When you `join()` again, the joinpoint starts back from the root. We can also backtrack to the beginning explicitly using `reset_joinpoint()`. This instruction will place the joinpoint back at the root `users` table, where subsequent `filter_by()` criterion are again against `users`: + + {python} + {sql}>>> session.query(User).join('addresses').\ + ... filter_by(email_address='jack@google.com').\ + ... 
reset_joinpoint().filter_by(name='jack').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users JOIN addresses ON users.id = addresses.user_id + WHERE addresses.email_address = ? AND users.name = ? ORDER BY users.oid + ['jack@google.com', 'jack'] + {stop}[] + +In all cases, we can get the `User` and the matching `Address` objects back at the same time, by telling the session we want both. This returns the results as a list of tuples: + + {python} + {sql}>>> session.query(User).add_entity(Address).join('addresses').\ + ... filter(Address.email_address=='jack@google.com').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password, addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM users JOIN addresses ON users.id = addresses.user_id + WHERE addresses.email_address = ? ORDER BY users.oid + ['jack@google.com'] + {stop}[(, )] + +Another common scenario is the need to join on the same table more than once. For example, if we want to find a `User` who has two distinct email addresses, both `jack@google.com` as well as `j25@yahoo.com`, we need to join to the `Addresses` table twice. SQLAlchemy does provide `Alias` objects which can accomplish this; but far easier is just to tell `join()` to alias for you: + + {python} + {sql}>>> session.query(User).\ + ... join('addresses', aliased=True).filter(Address.email_address=='jack@google.com').\ + ... join('addresses', aliased=True).filter(Address.email_address=='j25@yahoo.com').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id JOIN addresses AS addresses_2 ON users.id = addresses_2.user_id + WHERE addresses_1.email_address = ? AND addresses_2.email_address = ? ORDER BY users.oid + ['jack@google.com', 'j25@yahoo.com'] + {stop}[] + +The key thing which occurred above is that our SQL criterion were **aliased** as appropriate corresponding to the alias generated in the most recent `join()` call. + +The next section describes some "higher level" operators, including `any()` and `has()`, which make patterns like joining to multiple aliases unnecessary in most cases. + +### Relation Operators + +A summary of all operators usable on relations: + +* Filter on explicit column criterion, combined with a join. Column criterion can make usage of all supported SQL operators and expression constructs: + + {python} + {sql}>>> session.query(User).join('addresses').\ + ... filter(Address.email_address=='jack@google.com').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users JOIN addresses ON users.id = addresses.user_id + WHERE addresses.email_address = ? ORDER BY users.oid + ['jack@google.com'] + {stop}[] + + Criterion placed in `filter()` usually correspond to the last `join()` call; if the join was specified with `aliased=True`, class-level criterion against the join's target (or targets) will be appropriately aliased as well. + + {python} + {sql}>>> session.query(User).join('addresses', aliased=True).\ + ... 
filter(Address.email_address=='jack@google.com').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id + WHERE addresses_1.email_address = ? ORDER BY users.oid + ['jack@google.com'] + {stop}[] + +* Filter_by on key=value criterion, combined with a join. Same as `filter()` on column criterion except keyword arguments are used. + + {python} + {sql}>>> session.query(User).join('addresses').\ + ... filter_by(email_address='jack@google.com').all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users JOIN addresses ON users.id = addresses.user_id + WHERE addresses.email_address = ? ORDER BY users.oid + ['jack@google.com'] + {stop}[] + +* Filter on explicit column criterion using `any()` (for collections) or `has()` (for scalar relations). This is a more succinct method than joining, as an `EXISTS` subquery is generated automatically. `any()` means, "find all parent items where any child item of its collection meets this criterion": + + {python} + {sql}>>> session.query(User).\ + ... filter(User.addresses.any(Address.email_address=='jack@google.com')).all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE EXISTS (SELECT 1 + FROM addresses + WHERE users.id = addresses.user_id AND addresses.email_address = ?) ORDER BY users.oid + ['jack@google.com'] + {stop}[] + + `has()` means, "find all parent items where the child item meets this criterion": + + {python} + {sql}>>> session.query(Address).\ + ... filter(Address.user.has(User.name=='jack')).all() + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE EXISTS (SELECT 1 + FROM users + WHERE users.id = addresses.user_id AND users.name = ?) ORDER BY addresses.oid + ['jack'] + {stop}[, ] + + Both `has()` and `any()` also accept keyword arguments which are interpreted against the child classes' attributes: + + {python} + {sql}>>> session.query(User).\ + ... filter(User.addresses.any(email_address='jack@google.com')).all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE EXISTS (SELECT 1 + FROM addresses + WHERE users.id = addresses.user_id AND addresses.email_address = ?) ORDER BY users.oid + ['jack@google.com'] + {stop}[] + +* Filter_by on instance identity criterion. When comparing to a related instance, `filter_by()` will in most cases not need to reference the child table, since a child instance already contains enough information with which to generate criterion against the parent table. `filter_by()` uses an equality comparison for all relationship types. For many-to-one and one-to-one, this represents all objects which reference the given child object: + + {python} + # locate a user + {sql}>>> user = session.query(User).filter(User.name=='jack').one() #doctest: +NORMALIZE_WHITESPACE + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name = ? 
ORDER BY users.oid + LIMIT 2 OFFSET 0 + ['jack'] + {stop} + + # use the user in a filter_by() expression + {sql}>>> session.query(Address).filter_by(user=user).all() + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE ? = addresses.user_id ORDER BY addresses.oid + [5] + {stop}[, ] + + For one-to-many and many-to-many, it represents all objects which contain the given child object in the related collection: + + {python} + # locate an address + {sql}>>> address = session.query(Address).\ + ... filter(Address.email_address=='jack@google.com').one() #doctest: +NORMALIZE_WHITESPACE + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE addresses.email_address = ? ORDER BY addresses.oid + LIMIT 2 OFFSET 0 + {stop}['jack@google.com'] + + # use the address in a filter_by expression + {sql}>>> session.query(User).filter_by(addresses=address).all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.id = ? ORDER BY users.oid + [5] + {stop}[] + +* Select instances with a particular parent. This is the "reverse" operation of filtering by instance identity criterion; the criterion is against a relation pointing *to* the desired class, instead of one pointing *from* it. This will utilize the same "optimized" query criterion, usually not requiring any joins: + + {python} + {sql}>>> session.query(Address).with_parent(user, property='addresses').all() + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE ? = addresses.user_id ORDER BY addresses.oid + [5] + {stop}[, ] + +* Filter on a many-to-one/one-to-one instance identity criterion. The class-level `==` operator will act the same as `filter_by()` for a scalar relation: + + {python} + {sql}>>> session.query(Address).filter(Address.user==user).all() + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE ? = addresses.user_id ORDER BY addresses.oid + [5] + {stop}[, ] + + whereas the `!=` operator will generate a negated EXISTS clause: + + {python} + {sql}>>> session.query(Address).filter(Address.user!=user).all() + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE NOT (EXISTS (SELECT 1 + FROM users + WHERE users.id = addresses.user_id AND users.id = ?)) ORDER BY addresses.oid + [5] + {stop}[] + + a comparison to `None` also generates an IS NULL clause for a many-to-one relation: + + {python} + {sql}>>> session.query(Address).filter(Address.user==None).all() + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE addresses.user_id IS NULL ORDER BY addresses.oid + [] + {stop}[] + +* Filter on a one-to-many instance identity criterion. 
The `contains()` operator returns all parent objects which contain the given object as one of its collection members: + + {python} + {sql}>>> session.query(User).filter(User.addresses.contains(address)).all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.id = ? ORDER BY users.oid + [5] + {stop}[] + +* Filter on a multiple one-to-many instance identity criterion. The `==` operator can be used with a collection-based attribute against a list of items, which will generate multiple `EXISTS` clauses: + + {python} + {sql}>>> addresses = session.query(Address).filter(Address.user==user).all() + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE ? = addresses.user_id ORDER BY addresses.oid + [5] + {stop} + + {sql}>>> session.query(User).filter(User.addresses == addresses).all() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE (EXISTS (SELECT 1 + FROM addresses + WHERE users.id = addresses.user_id AND addresses.id = ?)) AND (EXISTS (SELECT 1 + FROM addresses + WHERE users.id = addresses.user_id AND addresses.id = ?)) ORDER BY users.oid + [1, 2] + {stop}[] + +## Deleting + +Let's try to delete `jack` and see how that goes. We'll mark as deleted in the session, then we'll issue a `count` query to see that no rows remain: + + {python} + >>> session.delete(jack) + {sql}>>> session.query(User).filter_by(name='jack').count() # doctest: +NORMALIZE_WHITESPACE + UPDATE addresses SET user_id=? WHERE addresses.id = ? + [None, 1] + UPDATE addresses SET user_id=? WHERE addresses.id = ? + [None, 2] + DELETE FROM users WHERE users.id = ? + [5] + SELECT count(users.id) AS count_1 + FROM users + WHERE users.name = ? + ['jack'] + {stop}0 + +So far, so good. How about Jack's `Address` objects ? + + {python} + {sql}>>> session.query(Address).filter( + ... Address.email_address.in_(['jack@google.com', 'j25@yahoo.com']) + ... ).count() # doctest: +NORMALIZE_WHITESPACE + SELECT count(addresses.id) AS count_1 + FROM addresses + WHERE addresses.email_address IN (?, ?) + ['jack@google.com', 'j25@yahoo.com'] + {stop}2 + +Uh oh, they're still there ! Analyzing the flush SQL, we can see that the `user_id` column of each address was set to NULL, but the rows weren't deleted. SQLAlchemy doesn't assume that deletes cascade, you have to tell it so. + +So let's rollback our work, and start fresh with new mappers that express the relationship the way we want: + + {python} + {sql}>>> session.rollback() # roll back the transaction + ROLLBACK + + >>> session.clear() # clear the session + >>> clear_mappers() # clear mappers + +We need to tell the `addresses` relation on `User` that we'd like session.delete() operations to cascade down to the child `Address` objects. Further, we also want `Address` objects which get detached from their parent `User`, whether or not the parent is deleted, to be deleted. For these behaviors we use two **cascade options** `delete` and `delete-orphan`, using the string-based `cascade` option to the `relation()` function: + + {python} + >>> mapper(User, users_table, properties={ # doctest: +ELLIPSIS + ... 'addresses':relation(Address, backref='user', cascade="all, delete, delete-orphan") + ... 
}) + + + >>> mapper(Address, addresses_table) # doctest: +ELLIPSIS + + +Now when we load Jack, removing an address from his `addresses` collection will result in that `Address` being deleted: + + {python} + # load Jack by primary key + {sql}>>> jack = session.query(User).get(jack.id) #doctest: +NORMALIZE_WHITESPACE + BEGIN + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.id = ? + [5] + {stop} + + # remove one Address (lazy load fires off) + {sql}>>> del jack.addresses[1] + SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id + FROM addresses + WHERE ? = addresses.user_id ORDER BY addresses.oid + [5] + {stop} + + # only one address remains + {sql}>>> session.query(Address).filter( + ... Address.email_address.in_(['jack@google.com', 'j25@yahoo.com']) + ... ).count() # doctest: +NORMALIZE_WHITESPACE + DELETE FROM addresses WHERE addresses.id = ? + [2] + SELECT count(addresses.id) AS count_1 + FROM addresses + WHERE addresses.email_address IN (?, ?) + ['jack@google.com', 'j25@yahoo.com'] + {stop}1 + +Deleting Jack will delete both Jack and his remaining `Address`: + + {python} + >>> session.delete(jack) + + {sql}>>> session.commit() + DELETE FROM addresses WHERE addresses.id = ? + [1] + DELETE FROM users WHERE users.id = ? + [5] + COMMIT + {stop} + + {sql}>>> session.query(User).filter_by(name='jack').count() # doctest: +NORMALIZE_WHITESPACE + BEGIN + SELECT count(users.id) AS count_1 + FROM users + WHERE users.name = ? + ['jack'] + {stop}0 + + {sql}>>> session.query(Address).filter( + ... Address.email_address.in_(['jack@google.com', 'j25@yahoo.com']) + ... ).count() # doctest: +NORMALIZE_WHITESPACE + SELECT count(addresses.id) AS count_1 + FROM addresses + WHERE addresses.email_address IN (?, ?) + ['jack@google.com', 'j25@yahoo.com'] + {stop}0 + +## Building a Many To Many Relation {@name=manytomany} + +We're moving into the bonus round here, but lets show off a many-to-many relationship. We'll sneak in some other features too, just to take a tour. We'll make our application a blog application, where users can write `BlogPost`s, which have `Keywords` associated with them. + +First some new tables: + + {python} + >>> from sqlalchemy import Text + >>> post_table = Table('posts', metadata, + ... Column('id', Integer, primary_key=True), + ... Column('user_id', Integer, ForeignKey('users.id')), + ... Column('headline', String(255), nullable=False), + ... Column('body', Text) + ... ) + + >>> post_keywords = Table('post_keywords', metadata, + ... Column('post_id', Integer, ForeignKey('posts.id')), + ... Column('keyword_id', Integer, ForeignKey('keywords.id'))) + + >>> keywords_table = Table('keywords', metadata, + ... Column('id', Integer, primary_key=True), + ... 
Column('keyword', String(50), nullable=False, unique=True)) + + {sql}>>> metadata.create_all(engine) # doctest: +NORMALIZE_WHITESPACE + PRAGMA table_info("users") + {} + PRAGMA table_info("addresses") + {} + PRAGMA table_info("posts") + {} + PRAGMA table_info("keywords") + {} + PRAGMA table_info("post_keywords") + {} + CREATE TABLE posts ( + id INTEGER NOT NULL, + user_id INTEGER, + headline VARCHAR(255) NOT NULL, + body TEXT, + PRIMARY KEY (id), + FOREIGN KEY(user_id) REFERENCES users (id) + ) + {} + COMMIT + CREATE TABLE keywords ( + id INTEGER NOT NULL, + keyword VARCHAR(50) NOT NULL, + PRIMARY KEY (id), + UNIQUE (keyword) + ) + {} + COMMIT + CREATE TABLE post_keywords ( + post_id INTEGER, + keyword_id INTEGER, + FOREIGN KEY(post_id) REFERENCES posts (id), + FOREIGN KEY(keyword_id) REFERENCES keywords (id) + ) + {} + COMMIT + +Then some classes: + + {python} + >>> class BlogPost(object): + ... def __init__(self, headline, body, author): + ... self.author = author + ... self.headline = headline + ... self.body = body + ... def __repr__(self): + ... return "BlogPost(%r, %r, %r)" % (self.headline, self.body, self.author) + + >>> class Keyword(object): + ... def __init__(self, keyword): + ... self.keyword = keyword + +And the mappers. `BlogPost` will reference `User` via its `author` attribute: + + {python} + >>> from sqlalchemy.orm import backref + + >>> mapper(Keyword, keywords_table) # doctest: +ELLIPSIS + + + >>> mapper(BlogPost, post_table, properties={ # doctest: +ELLIPSIS + ... 'author':relation(User, backref=backref('posts', lazy='dynamic')), + ... 'keywords':relation(Keyword, secondary=post_keywords) + ... }) + + +There's three new things in the above mapper: + + * the `User` relation has a backref, like we've used before, except this time it references a function called `backref()`. This function is used when yo'd like to specify keyword options for the backwards relationship. + * the keyword option we specified to `backref()` is `lazy="dynamic"`. This sets a default **loader strategy** on the attribute, in this case a special strategy that allows partial loading of results. + * The `keywords` relation uses a keyword argument `secondary` to indicate the **association table** for the many to many relationship from `BlogPost` to `Keyword`. + +Usage is not too different from what we've been doing. Let's give Wendy some blog posts: + + {python} + {sql}>>> wendy = session.query(User).filter_by(name='wendy').one() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password + FROM users + WHERE users.name = ? ORDER BY users.oid + LIMIT 2 OFFSET 0 + ['wendy'] + + >>> post = BlogPost("Wendy's Blog Post", "This is a test", wendy) + >>> session.save(post) + +We're storing keywords uniquely in the database, but we know that we don't have any yet, so we can just create them: + + {python} + >>> post.keywords.append(Keyword('wendy')) + >>> post.keywords.append(Keyword('firstpost')) + +We can now look up all blog posts with the keyword 'firstpost'. We'll use a special collection operator `any` to locate "blog posts where any of its keywords has the keyword string 'firstpost'": + + {python} + {sql}>>> session.query(BlogPost).filter(BlogPost.keywords.any(keyword='firstpost')).all() + INSERT INTO keywords (keyword) VALUES (?) + ['wendy'] + INSERT INTO keywords (keyword) VALUES (?) + ['firstpost'] + INSERT INTO posts (user_id, headline, body) VALUES (?, ?, ?) 
+ [2, "Wendy's Blog Post", 'This is a test'] + INSERT INTO post_keywords (post_id, keyword_id) VALUES (?, ?) + [[1, 1], [1, 2]] + SELECT posts.id AS posts_id, posts.user_id AS posts_user_id, posts.headline AS posts_headline, posts.body AS posts_body + FROM posts + WHERE EXISTS (SELECT 1 + FROM post_keywords, keywords + WHERE posts.id = post_keywords.post_id AND keywords.id = post_keywords.keyword_id AND keywords.keyword = ?) ORDER BY posts.oid + ['firstpost'] + {stop}[BlogPost("Wendy's Blog Post", 'This is a test', )] + +If we want to look up just Wendy's posts, we can tell the query to narrow down to her as a parent: + + {python} + {sql}>>> session.query(BlogPost).with_parent(wendy).\ + ... filter(BlogPost.keywords.any(keyword='firstpost')).all() + SELECT posts.id AS posts_id, posts.user_id AS posts_user_id, posts.headline AS posts_headline, posts.body AS posts_body + FROM posts + WHERE ? = posts.user_id AND (EXISTS (SELECT 1 + FROM post_keywords, keywords + WHERE posts.id = post_keywords.post_id AND keywords.id = post_keywords.keyword_id AND keywords.keyword = ?)) ORDER BY posts.oid + [2, 'firstpost'] + {stop}[BlogPost("Wendy's Blog Post", 'This is a test', )] + +Or we can use Wendy's own `posts` relation, which is a "dynamic" relation, to query straight from there: + + {python} + {sql}>>> wendy.posts.filter(BlogPost.keywords.any(keyword='firstpost')).all() + SELECT posts.id AS posts_id, posts.user_id AS posts_user_id, posts.headline AS posts_headline, posts.body AS posts_body + FROM posts + WHERE ? = posts.user_id AND (EXISTS (SELECT 1 + FROM post_keywords, keywords + WHERE posts.id = post_keywords.post_id AND keywords.id = post_keywords.keyword_id AND keywords.keyword = ?)) ORDER BY posts.oid + [2, 'firstpost'] + {stop}[BlogPost("Wendy's Blog Post", 'This is a test', )] + +## Further Reference + +Generated Documentation for Query: [docstrings_sqlalchemy.orm.query_Query](rel:docstrings_sqlalchemy.orm.query_Query) + +ORM Generated Docs: [docstrings_sqlalchemy.orm](rel:docstrings_sqlalchemy.orm) + +Further information on mapping setups are in [advdatamapping](rel:advdatamapping). + +Further information on working with Sessions: [unitofwork](rel:unitofwork). diff --git a/doc/build/content/plugins.txt b/doc/build/content/plugins.txt index dbc85a6f99..c144826d7b 100644 --- a/doc/build/content/plugins.txt +++ b/doc/build/content/plugins.txt @@ -1,280 +1,206 @@ Plugins {@name=plugins} ====================== -SQLAlchemy has a variety of extensions and "mods" available which provide extra functionality to SA, either via explicit usage or by augmenting the core behavior. Several of these extensions are designed to work together. +SQLAlchemy has a variety of extensions available which provide extra functionality to SA, either via explicit usage or by augmenting the core behavior. Several of these extensions are designed to work together. -### SessionContext +### declarative -**Author:** Daniel Miller +**Author:** Mike Bayer
+**Version:** 0.4.4 or greater -This plugin is used to instantiate and manage Session objects. It is the preferred way to provide thread-local session functionality to an application. It provides several services: +`declarative` intends to be a fully featured replacement for the very old `activemapper` extension. Its goal is to redefine the organization of class, `Table`, and `mapper()` constructs such that they can all be defined "at once" underneath a class declaration. Unlike `activemapper`, it does not redefine normal SQLAlchemy configurational semantics - regular `Column`, `relation()` and other schema or ORM constructs are used in almost all cases. -* serves as a factory to create sessions of a particular configuration. This factory may either call `create_session()` with a particular set of arguments, or instantiate a different implementation of `Session` if one is available. -* for the `Session` objects it creates, provides the ability to maintain a single `Session` per distinct application thread. The `Session` returned by a `SessionContext` is called the *contextual session.* Providing at least a thread-local context to sessions is important because the `Session` object is not threadsafe, and is intended to be used with localized sets of data, as opposed to a single session being used application wide. -* besides maintaining a single `Session` per thread, the contextual algorithm can be changed to support any kind of contextual scheme. -* provides a `MapperExtension` that can enhance a `Mapper`, such that it can automatically `save()` newly instantiated objects to the current contextual session. It also allows `Query` objects to be created without an explicit `Session`. While this is very convenient functionality, having it switched on without understanding it can be very confusing. Note that this feature is optional when using `SessionContext`. +`declarative` is a so-called "micro declarative layer"; it does not generate table or column names and requires almost as fully verbose a configuration as that of straight tables and mappers. As an alternative, the [Elixir](http://elixir.ematia.de/) project is a full community-supported declarative layer for SQLAlchemy, and is recommended for its active-record-like semantics, its convention-based configuration, and plugin capabilities. -Using the SessionContext in its most basic form involves just instantiating a `SessionContext`: +SQLAlchemy object-relational configuration involves the usage of Table, mapper(), and class objects to define the three areas of configuration. +declarative moves these three types of configuration underneath the individual mapped class. Regular SQLAlchemy schema and ORM constructs are used +in most cases: {python} - import sqlalchemy - from sqlalchemy.ext.sessioncontext import SessionContext - - ctx = SessionContext(sqlalchemy.create_session) + from sqlalchemy.ext.declarative import declarative_base - class User(object): - pass + Base = declarative_base() - mapper(User, users_table) - u = User() - - # the contextual session is referenced by the "current" property on SessionContext - ctx.current.save(u) - ctx.current.flush() + class SomeClass(Base): + __tablename__ = 'some_table' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) -From this example, one might see that the `SessionContext`'s typical *scope* is at the module or application level. 
Since the `Session` itself is better suited to be used in per-user-request or even per-function scope, the `SessionContext` provides an easy way to manage the scope of those `Session` objects. +Above, the `declarative_base` callable produces a new base class from which all mapped classes inherit from. When the class definition is +completed, a new `Table` and `mapper()` have been generated, accessible via the `__table__` and `__mapper__` attributes on the +`SomeClass` class. -The construction of each `Session` instance can be customized by providing a "creation function" which returns a new `Session`. A common customization is a `Session` which needs to explicitly bind to a particular `Engine`: +Attributes may be added to the class after its construction, and they will be added to the underlying `Table` and `mapper()` definitions as +appropriate: {python} - import sqlalchemy - from sqlalchemy.ext.sessioncontext import SessionContext - - # create an engine - someengine = sqlalchemy.create_engine('sqlite:///') - - # a function to return a Session bound to our engine - def make_session(): - return sqlalchemy.create_session(bind_to=someengine) - - # SessionContext - ctx = SessionContext(make_session) - - # get the session bound to engine "someengine": - session = ctx.current - -The above pattern is more succinctly expressed using Python lambdas: + SomeClass.data = Column('data', Unicode) + SomeClass.related = relation(RelatedInfo) - {python} - ctx = SessionContext(lambda:sqlalchemy.create_session(bind_to=someengine)) +Classes which are mapped explicitly using `mapper()` can interact freely with declarative classes. -The default creation function is simply: +The `declarative_base` base class contains a `MetaData` object where newly defined `Table` objects are collected. This is accessed via the ``metadata`` class level accessor, so to create tables we can say: {python} - ctx = SessionContext(sqlalchemy.create_session) + engine = create_engine('sqlite://') + Base.metadata.create_all(engine) -The "scope" to which the session is associated, which by default is a thread-local scope, can be customized by providing a "scope callable" which returns a hashable key that represents the current scope: +The `Engine` created above may also be directly associated with the declarative base class using the `engine` keyword argument, where it will be associated with the underlying `MetaData` object and allow SQL operations involving that metadata and its tables to make use of that engine automatically: {python} - import sqlalchemy - from sqlalchemy.ext.sessioncontext import SessionContext - - # global declaration of "scope" - scope = "scope1" - - # a function to return the current "session scope" - def global_scope_func(): - return scope + Base = declarative_base(engine=create_engine('sqlite://')) - # create SessionContext with a custom "scopefunc" - ctx = SessionContext(sqlalchemy.create_session, scopefunc=global_scope_func) - - # get the session corresponding to "scope1": - session = ctx.current - - # switch the "scope" - scope = "scope2" - - # get the session corresponding to "scope2": - session = ctx.current - -Examples of customized scope can include user-specific sessions or requests, or even sub-elements of an application, such as a graphical application which maintains a single `Session` per application window (this was the original motivation to create SessionContext). 
+Or, as `MetaData` allows, at any time using the `bind` attribute:
-
-#### Using SessionContextExt {@name=sessioncontextext}
-
-This is a `MapperExtension` which allows a `Mapper` to be automatically associated with a `SessionContext`. Newly constructed objects get `save()`d to the session automatically, and `Query` objects can be constructed without a session. The instance of `SessionContextExt` is provided by the `SessionContext` itself:
+    {python}
+    Base.metadata.bind = create_engine('sqlite://')
+
+The `declarative_base` can also receive a pre-created `MetaData` object, which allows a declarative setup to be associated with an already existing traditional collection of `Table` objects:
 
     {python}
-    import sqlalchemy
-    from sqlalchemy.ext.sessioncontext import SessionContext
-
-    ctx = SessionContext(sqlalchemy.create_session)
-
-    class User(object):
-        pass
-
-    mapper(User, users_table, extension=ctx.mapper_extension)
+    mymetadata = MetaData()
+    Base = declarative_base(metadata=mymetadata)
 
-    # 'u' is automatically added to the current session of 'ctx'
-    u = User()
-
-    assert u in ctx.current
-
-    # get the current session and flush
-    ctx.current.flush()
-
-The `MapperExtension` can be configured either per-mapper as above, or on an application-wide basis using:
+Relations to other classes are done in the usual way, with the added feature that the class specified to `relation()` may be a string name. The
+"class registry" associated with `Base` is used at mapper compilation time to resolve the name into the actual class object, which is expected to
+have been defined once the mapper configuration is used:
 
     {python}
-    import sqlalchemy
-    from sqlalchemy.orm.mapper import global_extensions
-    from sqlalchemy.ext.sessioncontext import SessionContext
+    class User(Base):
+        __tablename__ = 'users'
+
+        id = Column('id', Integer, primary_key=True)
+        name = Column('name', String(50))
+        addresses = relation("Address", backref="user")
 
-    ctx = SessionContext(sqlalchemy.create_session)
+    class Address(Base):
+        __tablename__ = 'addresses'
 
-    global_extensions.append(ctx.mapper_extension)
+        id = Column('id', Integer, primary_key=True)
+        email = Column('email', String(50))
+        user_id = Column('user_id', Integer, ForeignKey('users.id'))
 
-SessionContextExt allows `Query` objects to be created against the mapped class without specifying a `Session`. Each `Query` will automatically make usage of the current contextual session:
+Column constructs, since they are just that, are immediately usable, as below where we define a primary join condition on the `Address` class
+using them:
 
     {python}
-    # create a Query from a class
-    query = Query(User)
-
-    # specify entity name
-    query = Query(User, entity_name='foo')
-
-    # create a Query from a mapper
-    query = Query(mapper)
+    class Address(Base):
+        __tablename__ = 'addresses'
 
-    # then use it
-    result = query.select()
-
-When installed globally, all `Mapper` objects will contain a built-in association to the `SessionContext`. This means that once a mapped instance is created, creating a new `Session` and calling `save()` with the instance as an argument will raise an error stating that the instance is already associated with a different session. While you can always remove the object from its original session, `SessionContextExt` is probably convenient only for an application that does not need much explicit manipulation of sessions.
+        id = Column('id', Integer, primary_key=True)
+        email = Column('email', String(50))
+        user_id = Column('user_id', Integer, ForeignKey('users.id'))
+        user = relation(User, primaryjoin=user_id==User.id)
 
-The user still has some control over which session gets used at instance construction time. An instance can be redirected at construction time to a different `Session` by specifying the keyword parameter `_sa_session` to its constructor, which is decorated by the mapper:
+When an explicit join condition or other configuration which depends
+on multiple classes cannot be defined immediately due to some classes
+not yet being available, these can be defined after all classes have
+been created. Attributes which are added to the class after
+its creation are associated with the Table/mapping in the same
+way as if they had been defined inline:
 
     {python}
-    session = create_session() # create a new session distinct from the contextual session
-    myuser = User(_sa_session=session) # make a new User that is saved to this session
+    User.addresses = relation(Address, primaryjoin=Address.user_id==User.id)
 
-Similarly, the `entity_name` parameter, which specifies an alternate `Mapper` to be used when attaching this instance to the `Session`, can be specified via `_sa_entity_name`:
+Synonyms are one area where `declarative` needs to slightly change the usual SQLAlchemy configurational syntax. To define a
+getter/setter which proxies to an underlying attribute, use `synonym` with the `instruments` argument:
 
     {python}
-    myuser = User(_sa_session=session, _sa_entity_name='altentity')
-
-The decoration of mapped instances' `__init__()` method is similar to this example:
+    class MyClass(Base):
+        __tablename__ = 'sometable'
+
+        _attr = Column('attr', String)
+
+        def _get_attr(self):
+            return self._some_attr
+        def _set_attr(self, attr):
+            self._some_attr = attr
+        attr = synonym('_attr', instruments=property(_get_attr, _set_attr))
+
+The above synonym is then usable as an instance attribute as well as a class-level expression construct:
 
     {python}
-    oldinit = class_.__init__   # the previous init method
-    def __init__(self, *args, **kwargs):
-        session = kwargs.pop('_sa_session', None)
-        entity_name = kwargs.pop('_sa_entity_name', None)
-        if session is None:
-            session = ext.get_session() # get Session from this Mapper's MapperExtension
-            if session is EXT_PASS:
-                session = None
-        if session is not None:
-            session.save(self, entity_name=entity_name) # attach to the current session
-        oldinit(self, *args, **kwagrs) # call previous init method
-
-### SelectResults
-
-**Author:** Jonas Borgström
-
-*NOTE:* As of verison 0.3.6 of SQLAlchemy, most behavior of `SelectResults` has been rolled into the base `Query` object. Explicit usage of `SelectResults` is therefore no longer needed.
+    x = MyClass()
+    x.attr = "some value"
+    session.query(MyClass).filter(MyClass.attr == 'some other value').all()
 
-`SelectResults` gives transformative behavior to the results returned from the `select` and `select_by` methods of `Query`.
+The `synonym_for` decorator can accomplish the same task:
 
     {python}
-    from sqlalchemy.ext.selectresults import SelectResults
+    class MyClass(Base):
+        __tablename__ = 'sometable'
+
+        _attr = Column('attr', String)
 
-    query = session.query(MyClass)
-    res = SelectResults(query)
-
-    res = res.filter(table.c.column == "something") # adds a WHERE clause (or appends to the existing via "and")
-    res = res.order_by([table.c.column]) # adds an ORDER BY clause
+        @synonym_for('_attr')
+        @property
+        def attr(self):
+            return self._some_attr
 
-    for x in res[:10]:  # Fetch and print the top ten instances - adds OFFSET 0 LIMIT 10 or equivalent
-        print x.column2
+Similarly, `comparable_using` is a front end for the `comparable_property` ORM function:
 
-    # evaluate as a list, which executes the query
-    x = list(res)
+    {python}
+    class MyClass(Base):
+        __tablename__ = 'sometable'
 
-    # Count how many instances that have column2 > 42
-    # and column == "something"
-    print res.filter(table.c.column2 > 42).count()
+        name = Column('name', String)
 
-    # select() is a synonym for filter()
-    session.query(MyClass).select(mytable.c.column=="something").order_by([mytable.c.column])[2:7]
+        @comparable_using(MyUpperCaseComparator)
+        @property
+        def uc_name(self):
+            return self.name.upper()
 
-An important facet of SelectResults is that the actual SQL execution does not occur until the object is used in a list or iterator context. This means you can call any number of transformative methods (including `filter`, `order_by`, list range expressions, etc) before any SQL is actually issued.
+As an alternative to `__tablename__`, a direct `Table` construct may be used. The `Column` objects, which in this case require their names, will be added to the mapping just like a regular mapping to a table:
 
-Configuration of SelectResults may be per-Query, per Mapper, or per application:
 
     {python}
-    from sqlalchemy.ext.selectresults import SelectResults, SelectResultsExt
-
-    # construct a SelectResults for an individual Query
-    sel = SelectResults(session.query(MyClass))
-
-    # construct a Mapper where the Query.select()/select_by() methods will return a SelectResults:
-    mapper(MyClass, mytable, extension=SelectResultsExt())
-
-    # globally configure all Mappers to return SelectResults, using the "selectresults" mod
-    import sqlalchemy.mods.selectresults
+    class MyClass(Base):
+        __table__ = Table('my_table', Base.metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(50))
+        )
 
-SelectResults greatly enhances querying and is highly recommended. For example, heres an example of constructing a query using a combination of joins and outerjoins:
+This is the preferred approach when using reflected tables, as below:
 
     {python}
-    mapper(User, users_table, properties={
-        'orders':relation(mapper(Order, orders_table, properties={
-            'items':relation(mapper(Item, items_table))
-        }))
-    })
-    session = create_session()
-    query = SelectResults(session.query(User))
-
-    result = query.outerjoin_to('orders').outerjoin_to('items').select(or_(Order.c.order_id==None,Item.c.item_id==2))
-
-For a full listing of methods, see the [generated documentation](rel:docstrings_sqlalchemy.ext.selectresults).
-
-### assignmapper
+    class MyClass(Base):
+        __table__ = Table('my_table', Base.metadata, autoload=True)
 
-**Author:** Mike Bayer
-
-This extension is used to decorate a mapped class with direct knowledge about its own `Mapper`, a contextual `Session`, as well as functions provided by the `Query` and `Session` objects.
The methods will automatically make usage of a contextual session with which all newly constructed objects are associated. `assign_mapper` operates as a `MapperExtension`, and requires the usage of a `SessionContext` as well as `SessionContextExt`, described in [plugins_sessioncontext](rel:plugins_sessioncontext). It replaces the usage of the normal `mapper` function with its own version that adds a `SessionContext` specified as the first argument: +Mapper arguments are specified using the `__mapper_args__` class variable. Note that the column objects declared on the class are immediately +usable, as in this joined-table inheritance example: {python} - import sqlalchemy - from sqlalchemy.ext.sessioncontext import SessionContext - from sqlalchemy.ext.assignmapper import assign_mapper - - # session context - ctx = SessionContext(sqlalchemy.create_session) - - # assign mapper to class MyClass using table 'sometable', getting - # Sessions from 'ctx'. - assign_mapper(ctx, MyClass, sometable, properties={...}, ...) + class Person(Base): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity':'engineer'} + id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column('primary_language', String(50)) + +For single-table inheritance, the `__tablename__` and `__table__` class variables are optional on a class when the class inherits from another +mapped class. -Above, all new instances of `MyClass` will be associated with the contextual session, `ctx.current`. Additionally, `MyClass` and instances of `MyClass` now contain a large set of methods including `get`, `select`, `flush`, `delete`. The full list is as follows: +As a convenience feature, the `declarative_base()` sets a default constructor on classes which takes keyword arguments, and assigns them to the +named attributes: {python} - # Query methods: - ['get', 'select', 'select_by', 'selectone', 'get_by', 'join_to', 'join_via', 'count', 'count_by'] - - # Session methods: - ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update'] + e = Engineer(primary_language='python') -To continue the `MyClass` example: +Note that `declarative` has no integration built in with sessions, and is only intended as an optional syntax for the regular usage of mappers +and Table objects. A typical application setup using `scoped_session` might look like: {python} - # create a MyClass. it will be automatically assigned to the contextual Session. - mc = MyClass() + engine = create_engine('postgres://scott:tiger@localhost/test') + Session = scoped_session(sessionmaker(transactional=True, autoflush=False, bind=engine)) + Base = declarative_base() - # save MyClass - this will call flush() on the session, specifying 'mc' as the only - # object to be affected - mc.flush() - - # load an object, using Query methods attached to MyClass - result = MyClass.get_by(id=5) - - # delete it - result.delete() - - # commit all changes - ctx.current.flush() +Mapped instances then make usage of `Session` in the usual way. -**Note:** : while the `flush()` method is also available on individual object instances, the instance-local flush() **does not flush dependent objects**. For this reason this method may be removed in a future release and replaced with a more explicit version. 
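+As a quick illustration of that usual flow, here is a minimal sketch; it assumes the `Base`, `Session`, `engine`, and `Engineer` definitions from the examples above, and the names are purely illustrative:
+
+    {python}
+    # create tables for everything collected on Base.metadata (sketch only)
+    Base.metadata.create_all(engine)
+
+    # the scoped_session "Session" is invoked just like a sessionmaker() class
+    session = Session()
+    session.save(Engineer(primary_language='python'))
+    session.commit()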
### associationproxy
@@ -285,6 +211,8 @@ To continue the `MyClass` example:
 
 #### Simplifying Relations
 
+Consider this "association object" mapping:
+
     {python}
     users_table = Table('users', metadata,
         Column('id', Integer, primary_key=True),
@@ -311,10 +239,10 @@ To continue the `MyClass` example:
         def __init__(self, keyword):
             self.keyword = keyword
 
-    mapper(User, users, properties={
-        'kw': relation(Keyword, secondary=userkeywords)
+    mapper(User, users_table, properties={
+        'kw': relation(Keyword, secondary=userkeywords_table)
     })
-    mapper(Keyword, keywords)
+    mapper(Keyword, keywords_table)
 
 Above are three simple tables, modeling users, keywords and a many-to-many relationship between the two. These ``Keyword`` objects are little more than a container for a name, and accessing them via the relation is awkward:
@@ -419,7 +347,7 @@ Association proxies are also useful for keeping [association objects](rel:datama
 
     # Adding a Keyword requires creating a UserKeyword association object
     user.user_keywords.append(UserKeyword(user, kw1))
 
-    # And accessing Keywords requires traverrsing UserKeywords
+    # And accessing Keywords requires traversing UserKeywords
    print user.user_keywords[0]
    # <__main__.UserKeyword object at 0xb79bbbec>
 
@@ -611,13 +539,110 @@ index to any value you require.
 See the [module documentation](rel:docstrings_sqlalchemy.ext.orderinglist) for more
 information, and also check out the unit tests for examples of stepped numbering,
 alphabetical and Fibonacci numbering.
+
+### SqlSoup
+
+**Author:** Jonathan Ellis
+
+SqlSoup creates mapped classes on the fly from tables, which are automatically reflected from the database based on name. It is essentially a nicer version of the "row data gateway" pattern.
+
+    {python}
+    >>> from sqlalchemy.ext.sqlsoup import SqlSoup
+    >>> db = SqlSoup('sqlite:///')
+
+    >>> db.users.select(order_by=[db.users.c.name])
+    [MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1),
+     MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)]
+
+Full SqlSoup documentation is on the [SQLAlchemy Wiki](http://www.sqlalchemy.org/trac/wiki/SqlSoup).
+
+
+### Deprecated Extensions
+
+A lot of our extensions are deprecated. But this is a good thing. Why? Because all of them have been refined and focused, and rolled into the core of SQLAlchemy. So they aren't removed, they've just graduated into fully integrated features. Below we describe a set of extensions which are present in 0.4 but are deprecated.
+
+#### SelectResults
+
+**Author:** Jonas Borgström
+
+*NOTE:* As of version 0.3.6 of SQLAlchemy, most behavior of `SelectResults` has been rolled into the base `Query` object. Explicit usage of `SelectResults` is therefore no longer needed.
+
+`SelectResults` gives transformative behavior to the results returned from the `select` and `select_by` methods of `Query`.
+ + {python} + from sqlalchemy.ext.selectresults import SelectResults + + query = session.query(MyClass) + res = SelectResults(query) + res = res.filter(table.c.column == "something") # adds a WHERE clause (or appends to the existing via "and") + res = res.order_by([table.c.column]) # adds an ORDER BY clause + + for x in res[:10]: # Fetch and print the top ten instances - adds OFFSET 0 LIMIT 10 or equivalent + print x.column2 + + # evaluate as a list, which executes the query + x = list(res) + + # Count how many instances that have column2 > 42 + # and column == "something" + print res.filter(table.c.column2 > 42).count() -### ActiveMapper + # select() is a synonym for filter() + session.query(MyClass).select(mytable.c.column=="something").order_by([mytable.c.column])[2:7] + +An important facet of SelectResults is that the actual SQL execution does not occur until the object is used in a list or iterator context. This means you can call any number of transformative methods (including `filter`, `order_by`, list range expressions, etc) before any SQL is actually issued. + +Configuration of SelectResults may be per-Query, per Mapper, or per application: + + {python} + from sqlalchemy.ext.selectresults import SelectResults, SelectResultsExt + + # construct a SelectResults for an individual Query + sel = SelectResults(session.query(MyClass)) + + # construct a Mapper where the Query.select()/select_by() methods will return a SelectResults: + mapper(MyClass, mytable, extension=SelectResultsExt()) + + # globally configure all Mappers to return SelectResults, using the "selectresults" mod + import sqlalchemy.mods.selectresults + +SelectResults greatly enhances querying and is highly recommended. For example, heres an example of constructing a query using a combination of joins and outerjoins: + + {python} + mapper(User, users_table, properties={ + 'orders':relation(mapper(Order, orders_table, properties={ + 'items':relation(mapper(Item, items_table)) + })) + }) + session = create_session() + query = SelectResults(session.query(User)) + + result = query.outerjoin_to('orders').outerjoin_to('items').select(or_(Order.c.order_id==None,Item.c.item_id==2)) + +For a full listing of methods, see the [generated documentation](rel:docstrings_sqlalchemy.ext.selectresults). + +#### SessionContext + +**Author:** Daniel Miller + +The `SessionContext` extension is still available in the 0.4 release of SQLAlchemy, but has been deprecated in favor of the [scoped_session()](rel:unitofwork_contextual) function, which provides a class-like object that constructs a `Session` on demand which references a thread-local scope. + +For docs on `SessionContext`, see the SQLAlchemy 0.3 documentation. + +#### assignmapper + +**Author:** Mike Bayer + +The `assignmapper` extension is still available in the 0.4 release of SQLAlchemy, but has been deprecated in favor of the [scoped_session()](rel:unitofwork_contextual) function, which provides a `mapper` callable that works similarly to `assignmapper`. + +For docs on `assignmapper`, see the SQLAlchemy 0.3 documentation. + +#### ActiveMapper **Author:** Jonathan LaCour -Please note that ActiveMapper has been deprecated in favor of [Elixir](http://elixir.ematia.de/), a more comprehensive solution to declarative mapping, of which Jonathan is a co-author. 
+Please note that ActiveMapper has been deprecated in favor of either [Elixir](http://elixir.ematia.de/), a comprehensive solution to declarative mapping, or [declarative](rel:plugins_declarative), a built in convenience tool which reorganizes `Table` and `mapper()` configuration. ActiveMapper is a so-called "declarative layer" which allows the construction of a class, a `Table`, and a `Mapper` all in one step: @@ -672,19 +697,3 @@ ActiveMapper is a so-called "declarative layer" which allows the construction of More discussion on ActiveMapper can be found at [Jonathan LaCour's Blog](http://cleverdevil.org/computing/35/declarative-mapping-with-sqlalchemy) as well as the [SQLAlchemy Wiki](http://www.sqlalchemy.org/trac/wiki/ActiveMapper). -### SqlSoup - -**Author:** Jonathan Ellis - -SqlSoup creates mapped classes on the fly from tables, which are automatically reflected from the database based on name. It is essentially a nicer version of the "row data gateway" pattern. - - {python} - >>> from sqlalchemy.ext.sqlsoup import SqlSoup - >>> soup = SqlSoup('sqlite:///') - - >>> db.users.select(order_by=[db.users.c.name]) - [MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1), - MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)] - -Full SqlSoup documentation is on the [SQLAlchemy Wiki](http://www.sqlalchemy.org/trac/wiki/SqlSoup). - diff --git a/doc/build/content/pooling.txt b/doc/build/content/pooling.txt index ea9df87517..0703ea9a20 100644 --- a/doc/build/content/pooling.txt +++ b/doc/build/content/pooling.txt @@ -48,11 +48,6 @@ Common options include: * recycle=-1 : if set to non -1, a number of seconds between connection recycling, which means upon checkout, if this timeout is surpassed the connection will be closed and replaced with a newly opened connection. - * auto_close_cursors = True : cursors, returned by connection.cursor(), are tracked and are - automatically closed when the connection is returned to the pool. some DBAPIs like MySQLDB - become unstable if cursors remain open. - * disallow_open_cursors = False : if auto_close_cursors is False, and disallow_open_cursors is True, - will raise an exception if an open cursor is detected upon connection checkin. If auto_close_cursors and disallow_open_cursors are both False, then no cursor processing occurs upon checkin. QueuePool options include: diff --git a/doc/build/content/session.txt b/doc/build/content/session.txt new file mode 100644 index 0000000000..0e94cef244 --- /dev/null +++ b/doc/build/content/session.txt @@ -0,0 +1,816 @@ + Using the Session {@name=unitofwork} +============ + +The [Mapper](rel:advdatamapping) is the entrypoint to the configurational API of the SQLAlchemy object relational mapper. But the primary object one works with when using the ORM is the [Session](rel:docstrings_sqlalchemy.orm.session_Session). + +## What does the Session do ? + +In the most general sense, the `Session` establishes all conversations with the database and represents a "holding zone" for all the mapped instances which you've loaded or created during its lifespan. It implements the [Unit of Work](http://martinfowler.com/eaaCatalog/unitOfWork.html) pattern, which means it keeps track of all changes which occur, and is capable of **flushing** those changes to the database as appropriate. 
Another important facet of the `Session` is that it also maintains **unique** copies of each instance, where "unique" means "only one object with a particular primary key" - this pattern is called the [Identity Map](http://martinfowler.com/eaaCatalog/identityMap.html).
+
+Beyond that, the `Session` implements an interface which lets you move objects in or out of the session in a variety of ways; it provides the entryway to a `Query` object which is used to query the database for data; it is commonly used to provide transactional boundaries (though this is optional); and it can also serve as a configurational "home base" for one or more `Engine` objects, which allows various vertical and horizontal partitioning strategies to be achieved.
+
+## Getting a Session
+
+The `Session` object exists just as a regular Python object, which can be directly instantiated. However, it takes a fair number of keyword options, several of which you probably want to set explicitly. It's fairly inconvenient to deal with the "configuration" of a session every time you want to create one. Therefore, SQLAlchemy recommends the usage of a helper function called `sessionmaker()`, which typically you call only once for the lifespan of an application. This function creates a customized `Session` subclass for you, with your desired configurational arguments pre-loaded. Then, whenever you need a new `Session`, you use your custom `Session` class with no arguments to create the session.
+
+### Using a sessionmaker() Configuration {@name=sessionmaker}
+
+The usage of `sessionmaker()` is illustrated below:
+
+    {python}
+    from sqlalchemy.orm import sessionmaker
+
+    # create a configured "Session" class
+    Session = sessionmaker(autoflush=True, transactional=True)
+
+    # create a Session
+    sess = Session()
+
+    # work with sess
+    sess.save(x)
+    sess.commit()
+
+    # close when finished
+    sess.close()
+
+Above, the `sessionmaker` call creates a class for us, which we assign to the name `Session`. This class is a subclass of the actual `sqlalchemy.orm.session.Session` class, which will instantiate with the arguments of `autoflush=True` and `transactional=True`.
+
+When you write your application, place the call to `sessionmaker()` somewhere global, and then make your new `Session` class available to the rest of your application.
+
+### Binding Session to an Engine or Connection {@name=binding}
+
+In our previous example regarding `sessionmaker()`, nowhere did we specify how our session would connect to our database. When the session is configured in this manner, it will look for a database engine to connect with via the `Table` objects that it works with - the chapter called [metadata_tables_binding](rel:metadata_tables_binding) describes how to associate `Table` objects directly with a source of database connections.
+
+However, it is often more straightforward to explicitly tell the session what database engine (or engines) you'd like it to communicate with. This is particularly handy with multiple-database scenarios where the session can be used as the central point of configuration.
To achieve this, the constructor keyword `bind` is used for a basic single-database configuration: + + {python} + # create engine + engine = create_engine('postgres://...') + + # bind custom Session class to the engine + Session = sessionmaker(bind=engine, autoflush=True, transactional=True) + + # work with the session + sess = Session() + +One common issue with the above scenario is that an application will often organize its global imports before it ever connects to a database. Since the `Session` class created by `sessionmaker()` is meant to be a global application object (note we are saying the session *class*, not a session *instance*), we may not have a `bind` argument available. For this, the `Session` class returned by `sessionmaker()` supports post-configuration of all options, through its method `configure()`: + + {python} + # configure Session class with desired options + Session = sessionmaker(autoflush=True, transactional=True) + + # later, we create the engine + engine = create_engine('postgres://...') + + # associate it with our custom Session class + Session.configure(bind=engine) + + # work with the session + sess = Session() + +The `Session` also has the ability to be bound to multiple engines. Descriptions of these scenarios are described in [unitofwork_partitioning](rel:unitofwork_partitioning). + + +#### Binding Session to a Connection {@name=connection} + +The examples involving `bind` so far are dealing with the `Engine` object, which is, like the `Session` class itself, a global configurational object. The `Session` can also be bound to an individual database `Connection`. The reason you might want to do this is if your application controls the boundaries of transactions using distinct `Transaction` objects (these objects are described in [dbengine_transactions](rel:dbengine_transactions)). You'd have a transactional `Connection`, and then you'd want to work with an ORM-level `Session` which participates in that transaction. Since `Connection` is definitely not a globally-scoped object in all but the most rudimental commandline applications, you can bind an individual `Session()` instance to a particular `Connection` not at class configuration time, but at session instance construction time: + + {python} + # global application scope. create Session class, engine + Session = sessionmaker(autoflush=True, transactional=True) + + engine = create_engine('postgres://...') + + ... + + # local scope, such as within a controller function + + # connect to the database + connection = engine.connect() + + # bind an individual Session to the connection + sess = Session(bind=connection) + +### Using create_session() {@name=createsession} + +As an alternative to `sessionmaker()`, `create_session()` exists literally as a function which calls the normal `Session` constructor directly. All arguments are passed through and the new `Session` object is returned: + + {python} + session = create_session(bind=myengine) + +The `create_session()` function doesn't add any functionality to the regular `Session`, it just sets up a default argument set of `autoflush=False, transactional=False`. But also, by calling `create_session()` instead of instantiating `Session` directly, you leave room in your application to change the type of session which the function creates. 
For example, an application which is calling `create_session()` in many places, which is typical for a pre-0.4 application, can be changed to use a `sessionmaker()` by just assigning the return of `sessionmaker()` to the `create_session` name: + + {python} + # change from: + from sqlalchemy.orm import create_session + + # to: + create_session = sessionmaker() + +## Using the Session + +A typical session conversation starts with creating a new session, or acquiring one from an ongoing context. You save new objects and load existing ones, make changes, mark some as deleted, and then persist your changes to the database. If your session is transactional, you use `commit()` to persist any remaining changes and to commit the transaction. If not, you call `flush()` which will flush any remaining data to the database. + +Below, we open a new `Session` using a configured `sessionmaker()`, make some changes, and commit: + + {python} + # configured Session class + Session = sessionmaker(autoflush=True, transactional=True) + + sess = Session() + d = Data(value=10) + sess.save(d) + d2 = sess.query(Data).filter(Data.value==15).one() + d2.value = 19 + sess.commit() + +### Quickie Intro to Object States {@name=states} + +It's helpful to know the states which an instance can have within a session: + +* *Transient* - an instance that's not in a session, and is not saved to the database; i.e. it has no database identity. The only relationship such an object has to the ORM is that its class has a `mapper()` associated with it. + +* *Pending* - when you `save()` a transient instance, it becomes pending. It still wasn't actually flushed to the database yet, but it will be when the next flush occurs. + +* *Persistent* - An instance which is present in the session and has a record in the database. You get persistent instances by either flushing so that the pending instances become persistent, or by querying the database for existing instances (or moving persistent instances from other sessions into your local session). + +* *Detached* - an instance which has a record in the database, but is not in any session. Theres nothing wrong with this, and you can use objects normally when they're detached, **except** they will not be able to issue any SQL in order to load collections or attributes which are not yet loaded, or were marked as "expired". + +Knowing these states is important, since the `Session` tries to be strict about ambiguous operations (such as trying to save the same object to two different sessions at the same time). + +### Frequently Asked Questions {@name=faq} + +* When do I make a `sessionmaker` ? + + Just one time, somewhere in your application's global scope. It should be looked upon as part of your application's configuration. If your application has three .py files in a package, you could, for example, place the `sessionmaker` line in your `__init__.py` file; from that point on your other modules say "from mypackage import Session". That way, everyone else just uses `Session()`, and the configuration of that session is controlled by that central point. + + If your application starts up, does imports, but does not know what database it's going to be connecting to, you can bind the `Session` at the "class" level to the engine later on, using `configure()`. + + In the examples in this section, we will frequently show the `sessionmaker` being created right above the line where we actually invoke `Session()`. But that's just for example's sake ! 
In reality, the `sessionmaker` would be somewhere at the module level, and your individual `Session()` calls would be sprinkled all throughout your app, such as in a web application within each controller method. + +* When do I make a `Session` ? + + You typically invoke `Session()` when you first need to talk to your database, and want to save some objects or load some existing ones. Then, you work with it, save your changes, and then dispose of it....or at the very least `close()` it. It's not a "global" kind of object, and should be handled more like a "local variable", as it's generally **not** safe to use with concurrent threads. Sessions are very inexpensive to make, and don't use any resources whatsoever until they are first used...so create some ! + + There is also a pattern whereby you're using a **contextual session**, this is described later in [unitofwork_contextual](rel:unitofwork_contextual). In this pattern, a helper object is maintaining a `Session` for you, most commonly one that is local to the current thread (and sometimes also local to an application instance). SQLAlchemy 0.4 has worked this pattern out such that it still *looks* like you're creating a new session as you need one...so in that case, it's still a guaranteed win to just say `Session()` whenever you want a session. + +* Is the Session a cache ? + + Yeee...no. It's somewhat used as a cache, in that it implements the identity map pattern, and stores objects keyed to their primary key. However, it doesn't do any kind of query caching. This means, if you say `session.query(Foo).filter_by(name='bar')`, even if `Foo(name='bar')` is right there, in the identity map, the session has no idea about that. It has to issue SQL to the database, get the rows back, and then when it sees the primary key in the row, *then* it can look in the local identity map and see that the object is already there. It's only when you say `query.get({some primary key})` that the `Session` doesn't have to issue a query. + + Additionally, the Session stores object instances using a weak reference by default. This also defeats the purpose of using the Session as a cache, unless the `weak_identity_map` flag is set to `False`. + + The `Session` is not designed to be a global object from which everyone consults as a "registry" of objects. That is the job of a **second level cache**. A good library for implementing second level caching is [Memcached](http://www.danga.com/memcached/). It *is* possible to "sort of" use the `Session` in this manner, if you set it to be non-transactional and it never flushes any SQL, but it's not a terrific solution, since if concurrent threads load the same objects at the same time, you may have multiple copies of the same objects present in collections. + +* How can I get the `Session` for a certain object ? + + Use the `object_session()` classmethod available on `Session`: + + {python} + session = Session.object_session(someobject) + +* Is the session threadsafe ? + + Nope. It has no thread synchronization of any kind built in, and particularly when you do a flush operation, it definitely is not open to concurrent threads accessing it, because it holds onto a single database connection at that point. 
If you use a session which is non-transactional for read operations only, it's still not thread-"safe", but you also wont get any catastrophic failures either, since it opens and closes connections on an as-needed basis; it's just that different threads might load the same objects independently of each other, but only one will wind up in the identity map (however, the other one might still live in a collection somewhere). + + But the bigger point here is, you should not *want* to use the session with multiple concurrent threads. That would be like having everyone at a restaurant all eat from the same plate. The session is a local "workspace" that you use for a specific set of tasks; you don't want to, or need to, share that session with other threads who are doing some other task. If, on the other hand, there are other threads participating in the same task you are, such as in a desktop graphical application, then you would be sharing the session with those threads, but you also will have implemented a proper locking scheme (or your graphical framework does) so that those threads do not collide. + +### Session Attributes {@name=attributes} + +The session provides a set of attributes and collection-oriented methods which allow you to view the current state of the session. + +The **identity map** is accessed by the `identity_map` attribute, which provides a dictionary interface. The keys are "identity keys", which are attached to all persistent objects by the attribute `_instance_key`: + + {python} + >>> myobject._instance_key + (, (7,)) + + >>> myobject._instance_key in session.identity_map + True + + >>> session.identity_map.values() + [<__main__.User object at 0x712630>, <__main__.Address object at 0x712a70>] + +The identity map is a weak-referencing dictionary by default. This means that objects which are dereferenced on the outside will be removed from the session automatically. Note that objects which are marked as "dirty" will not fall out of scope until after changes on them have been flushed; special logic kicks in at the point of auto-removal which ensures that no pending changes remain on the object, else a temporary strong reference is created to the object. + +Some people prefer objects to stay in the session until explicitly removed in all cases; for this, you can specify the flag `weak_identity_map=False` to the `create_session` or `sessionmaker` functions so that the `Session` will use a regular dictionary. + +While the `identity_map` accessor is currently the actual dictionary used by the `Session` to store instances, you should not add or remove items from this dictionary. Use the session methods `save_or_update()` and `expunge()` to add or remove items. + +The Session also supports an iterator interface in order to see all objects in the identity map: + + {python} + for obj in session: + print obj + +As well as `__contains__()`: + + {python} + if obj in session: + print "Object is present" + +The session is also keeping track of all newly created (i.e. pending) objects, all objects which have had changes since they were last loaded or saved (i.e. "dirty"), and everything that's been marked as deleted. 
+
+    {python}
+    # pending objects recently added to the Session
+    session.new
+
+    # persistent objects which currently have changes detected
+    # (this collection is now created on the fly each time the property is called)
+    session.dirty
+
+    # persistent objects that have been marked as deleted via session.delete(obj)
+    session.deleted
+
+### Querying
+
+The `query()` function takes one or more classes and/or mappers, along with an optional `entity_name` parameter, and returns a new `Query` object which will issue mapper queries within the context of this Session. For each mapper passed, the Query uses that mapper. For each class, the Query will locate the primary mapper for the class using `class_mapper()`.
+
+    {python}
+    # query from a class
+    session.query(User).filter_by(name='ed').all()
+
+    # query with multiple classes, returns tuples
+    session.query(User).add_entity(Address).join('addresses').filter_by(name='ed').all()
+
+    # query from a mapper
+    query = session.query(usermapper)
+    x = query.get(1)
+
+    # query from a class mapped with entity name 'alt_users'
+    q = session.query(User, entity_name='alt_users')
+    y = q.options(eagerload('orders')).all()
+
+`entity_name` is an optional keyword argument sent with a class object, in order to further qualify which primary mapper is to be used; this only applies if there was a `Mapper` created with that particular class/entity name combination, else an exception is raised. All of the methods on Session which take a class or mapper argument also take the `entity_name` argument, so that a given class can be properly matched to the desired primary mapper.
+
+All instances retrieved by the returned `Query` object will be stored as persistent instances within the originating `Session`.
+
+### Saving New Instances
+
+`save()` is called with a single transient instance as an argument, which is then added to the Session and becomes pending. When the session is next flushed, the instance will be saved to the database. If the given instance is not transient, meaning it is either attached to an existing Session or it has a database identity, an exception is raised.
+
+    {python}
+    user1 = User(name='user1')
+    user2 = User(name='user2')
+    session.save(user1)
+    session.save(user2)
+
+    session.commit()     # write changes to the database
+
+There are also other ways to have objects saved to the session automatically; one is by using cascade rules, and the other is by using a contextual session. Both of these are described later.
+
+### Updating/Merging Existing Instances
+
+The `update()` method is used when you have a detached instance, and you want to put it back into a `Session`. Recall that "detached" means the object has a database identity.
+
+Since `update()` is a little picky that way, most people use `save_or_update()`, which checks for an `_instance_key` attribute, and based on whether it's there or not, calls either `save()` or `update()`:
+
+    {python}
+    # load user1 using session 1
+    user1 = sess1.query(User).get(5)
+
+    # remove it from session 1
+    sess1.expunge(user1)
+
+    # move it into session 2
+    sess2.save_or_update(user1)
+
+`update()` is also an operation that can happen automatically using cascade rules, just like `save()`.
+
+`merge()` on the other hand is a little like `update()`, except it creates a **copy** of the given instance in the session, and returns to you that instance; the instance you send it never goes into the session.
`merge()` is much fancier than `update()`; it will actually look to see if an object with the same primary key is already present in the session, and if not will load it by primary key. Then, it will merge the attributes of the given object into the one which it just located. + +This method is useful for bringing in objects which may have been restored from a serialization, such as those stored in an HTTP session, where the object may be present in the session already: + + {python} + # deserialize an object + myobj = pickle.loads(mystring) + + # "merge" it. if the session already had this object in the + # identity map, then you get back the one from the current session. + myobj = session.merge(myobj) + +`merge()` includes an important option called `dont_load`. When this boolean flag is set to `True`, the merge of a detached object will not force a `get()` of that object from the database. Normally, `merge()` issues a `get()` for every existing object so that it can load the most recent state of the object, which is then modified according to the state of the given object. With `dont_load=True`, the `get()` is skipped and `merge()` places an exact copy of the given object in the session. This allows objects which were retrieved from a caching system to be copied back into a session without any SQL overhead being added. + +### Deleting + +The `delete` method places an instance into the Session's list of objects to be marked as deleted: + + {python} + # mark two objects to be deleted + session.delete(obj1) + session.delete(obj2) + + # commit (or flush) + session.commit() + +The big gotcha with `delete()` is that **nothing is removed from collections**. Such as, if a `User` has a collection of three `Addresses`, deleting an `Address` will not remove it from `user.addresses`: + + {python} + >>> address = user.addresses[1] + >>> session.delete(address) + >>> session.flush() + >>> address in user.addresses + True + +The solution is to use proper cascading: + + {python} + mapper(User, users_table, properties={ + 'addresses':relation(Address, cascade="all, delete") + }) + del user.addresses[1] + session.flush() + +### Flushing + +This is the main gateway to what the `Session` does best, which is save everything ! It should be clear by now what a flush looks like: + + {python} + session.flush() + +It also can be called with a list of objects; in this form, the flush operation will be limited only to the objects specified in the list: + + {python} + # saves only user1 and address2. all other modified + # objects remain present in the session. + session.flush([user1, address2]) + +This second form of flush should be used carefully as it will not necessarily locate other dependent objects within the session, whose database representation may have foreign constraint relationships with the objects being operated upon. + +Theres also a way to have `flush()` called automatically before each query; this is called "autoflush" and is described below. + +Note that when using a `Session` that has been placed into a transaction, the `commit()` method will also `flush()` the `Session` unconditionally before committing the transaction. + +Note that flush **does not change** the state of any collections or entity relationships in memory; for example, if you set a foreign key attribute `b_id` on object `A` with the identifier `B.id`, the change will be flushed to the database, but `A` will not have `B` added to its collection. 
If you want to manipulate foreign key attributes directly, `refresh()` or `expire()` the objects whose state needs to be refreshed subsequent to flushing. + +### Autoflush + +A session can be configured to issue `flush()` calls before each query. This allows you to immediately have DB access to whatever has been saved to the session. It's recommended to use autoflush with `transactional=True`, that way an unexpected flush call won't permanently save to the database: + + {python} + Session = sessionmaker(autoflush=True, transactional=True) + sess = Session() + u1 = User(name='jack') + sess.save(u1) + + # reload user1 + u2 = sess.query(User).filter_by(name='jack').one() + assert u2 is u1 + + # commit session, flushes whatever is remaining + sess.commit() + +Autoflush is particularly handy when using "dynamic" mapper relations, so that changes to the underlying collection are immediately available via its query interface. + +### Committing + +The `commit()` method on `Session` is used specifically when the `Session` is in a transactional state. The two ways that a session may be placed in a transactional state are to create it using the `transactional=True` option, or to call the `begin()` method. + +`commit()` serves **two** purposes; it issues a `flush()` unconditionally to persist any remaining pending changes, and it issues a commit to all currently managed database connections. In the typical case this is just a single connection. After the commit, connection resources which were allocated by the `Session` are released. This holds true even for a `Session` which specifies `transactional=True`; when such a session is committed, the next transaction is not "begun" until the next database operation occurs. + +See the section below on "Managing Transactions" for further detail. + +### Expunge / Clear + +Expunge removes an object from the Session, sending persistent instances to the detached state, and pending instances to the transient state: + + {python} + session.expunge(obj1) + +Use `expunge` when you'd like to remove an object altogether from memory, such as before calling `del` on it, which will prevent any "ghost" operations occurring when the session is flushed. + +This `clear()` method is equivalent to `expunge()`-ing everything from the Session: + + {python} + session.clear() + +However note that the `clear()` method does not reset any transactional state or connection resources; therefore what you usually want to call instead of `clear()` is `close()`. + +### Closing + +The `close()` method issues a `clear()`, and releases any transactional/connection resources. When connections are returned to the connection pool, whatever transactional state exists is rolled back. + +When `close()` is called, the `Session` is in the same state as when it was first created, and is safe to be used again. `close()` is especially important when using a contextual session, which remains in memory after usage. By issuing `close()`, the session will be clean for the next request that makes use of it. + +### Refreshing / Expiring + +To assist with the Session's "sticky" behavior of instances which are present, individual objects can have all of their attributes immediately re-loaded from the database, or marked as "expired" which will cause a re-load to occur upon the next access of any of the object's mapped attributes. This includes all relationships, so lazy-loaders will be re-initialized, eager relationships will be repopulated. 
Any changes marked on the object are discarded: + + {python} + # immediately re-load attributes on obj1, obj2 + session.refresh(obj1) + session.refresh(obj2) + + # expire objects obj1, obj2, attributes will be reloaded + # on the next access: + session.expire(obj1) + session.expire(obj2) + +`refresh()` and `expire()` also support being passed a list of individual attribute names to be refreshed. These names can reference any attribute, column-based or relation-based: + + {python} + # immediately re-load the attributes 'hello', 'world' on obj1, obj2 + session.refresh(obj1, ['hello', 'world']) + session.refresh(obj2, ['hello', 'world']) + + # expire the attributes 'hello', 'world' on obj1, obj2; the attributes + # will be reloaded on the next access: + session.expire(obj1, ['hello', 'world']) + session.expire(obj2, ['hello', 'world']) + +## Cascades + +Mappers support the concept of configurable *cascade* behavior on `relation()`s. This behavior controls how the Session should treat the instances that have a parent-child relationship with another instance that is operated upon by the Session. Cascade is indicated as a comma-separated list of string keywords, with the possible values `all`, `delete`, `save-update`, `refresh-expire`, `merge`, `expunge`, and `delete-orphan`. + +Cascading is configured by setting the `cascade` keyword argument on a `relation()`: + + {python} + mapper(Order, order_table, properties={ + 'items' : relation(Item, cascade="all, delete-orphan"), + 'customer' : relation(User, secondary=user_orders_table, cascade="save-update"), + }) + +The above mapper specifies two relations, `items` and `customer`. The `items` relationship specifies "all, delete-orphan" as its `cascade` value, indicating that all `save`, `update`, `merge`, `expunge`, `refresh`, `delete` and `expire` operations performed on a parent `Order` instance should also be performed on the child `Item` instances attached to it (`save` and `update` are cascaded using the `save_or_update()` method, so that the database identity of the instance doesn't matter). The `delete-orphan` cascade value additionally indicates that if an `Item` instance is no longer associated with an `Order`, it should also be deleted. The "all, delete-orphan" cascade argument allows a so-called *lifecycle* relationship between an `Order` and an `Item` object. + +The `customer` relationship specifies only the "save-update" cascade value, indicating that most operations will not be cascaded from a parent `Order` instance to a child `User` instance, except when the `Order` is attached to a particular session via the `save()`, `update()`, or `save_or_update()` method. + +Additionally, when a child item is attached to a parent item that specifies the "save-update" cascade value on the relationship, the child is automatically passed to `save_or_update()` (and the operation is further cascaded to the child item). + +Note that cascading doesn't do anything that isn't possible by manually calling Session methods on individual instances within a hierarchy; it merely automates common operations on a group of associated instances. + +The default value for `cascade` on `relation()`s is `save-update, merge`. + +## Managing Transactions + +The Session can manage transactions automatically, including across multiple engines. When the Session is in a transaction, as it receives requests to execute SQL statements, it adds each individual Connection/Engine encountered to its transactional state.
At commit time, all unflushed data is flushed, and each individual transaction is committed. If the underlying databases support two-phase semantics, this may be used by the Session as well if two-phase transactions are enabled. + +The easiest way to use a Session with transactions is just to declare it as transactional. The session will remain in a transaction at all times: + + {python} + # transactional session + Session = sessionmaker(transactional=True) + sess = Session() + try: + item1 = sess.query(Item).get(1) + item2 = sess.query(Item).get(2) + item1.foo = 'bar' + item2.bar = 'foo' + + # commit- will immediately go into a new transaction afterwards + sess.commit() + except: + # rollback - will immediately go into a new transaction afterwards. + sess.rollback() + +Things to note above: + + * When using a transactional session, either a `rollback()` or a `close()` call **is required** when an error is raised by `flush()` or `commit()`. The `flush()` error condition will issue a ROLLBACK to the database automatically, but the state of the `Session` itself remains in an "undefined" state until the user decides whether to rollback or close. + * The `commit()` call unconditionally issues a `flush()`. Particularly when using `transactional=True` in conjunction with `autoflush=True`, explicit `flush()` calls are usually not needed. + +Alternatively, a transaction can be begun explicitly using `begin()`: + + {python} + # non transactional session + Session = sessionmaker(transactional=False) + sess = Session() + sess.begin() + try: + item1 = sess.query(Item).get(1) + item2 = sess.query(Item).get(2) + item1.foo = 'bar' + item2.bar = 'foo' + sess.commit() + except: + sess.rollback() + raise + +As in the `transactional` example, the same rules apply; an explicit `rollback()` or `close()` is required when an error occurs, and the `commit()` call issues a `flush()` as well. + +Session also supports Python 2.5's `with` statement, so that the example above can be written as: + + {python} + Session = sessionmaker(transactional=False) + sess = Session() + with sess.begin(): + item1 = sess.query(Item).get(1) + item2 = sess.query(Item).get(2) + item1.foo = 'bar' + item2.bar = 'foo' + +Subtransactions can be created by calling the `begin()` method repeatedly. For each transaction you `begin()`, you must always call either `commit()` or `rollback()`. Note that this includes the implicit transaction created by the transactional session. When a subtransaction is created, the current transaction of the session is set to that transaction. Committing the subtransaction will return you to the next outer transaction. Rolling it back will also return you to the next outer transaction, but in addition it will roll back database state to the innermost transaction that can actually be rolled back to; usually this means the root transaction, unless you use the nested transaction functionality via the `begin_nested()` method. MySQL and Postgres (and soon Oracle) support "nested" transactions by creating SAVEPOINTs: + + {python} + Session = sessionmaker(transactional=False) + sess = Session() + sess.begin() + sess.save(u1) + sess.save(u2) + sess.flush() + + sess.begin_nested() # establish a savepoint + sess.save(u3) + sess.rollback() # rolls back u3, keeps u1 and u2 + + sess.commit() # commits u1 and u2 + +Finally, for MySQL, Postgres, and soon Oracle as well, the session can be instructed to use two-phase commit semantics.
This will coordinate the committing of transactions across databases so that the transaction is either committed or rolled back in all databases. You can also `prepare()` the session for interacting with transactions not managed by SQLAlchemy. To use two-phase transactions, set the flag `twophase=True` on the session: + + {python} + engine1 = create_engine('postgres://db1') + engine2 = create_engine('postgres://db2') + + Session = sessionmaker(twophase=True, transactional=True) + + # bind User operations to engine 1, Account operations to engine 2 + Session.configure(binds={User:engine1, Account:engine2}) + + sess = Session() + + # .... work with accounts and users + + # commit. session will issue a flush to all DBs, and a prepare step to all DBs, + # before committing both transactions + sess.commit() + +Be aware that when a crash occurs in one of the databases while the transactions are prepared, you have to manually commit or roll back the prepared transactions in your database as appropriate. + +## Embedding SQL Insert/Update Expressions into a Flush {@name=flushsql} + +This feature allows the value of a database column to be set to a SQL expression instead of a literal value. It's especially useful for atomic updates, calling stored procedures, etc. All you do is assign an expression to an attribute: + + {python} + class SomeClass(object): + pass + mapper(SomeClass, some_table) + + someobject = session.query(SomeClass).get(5) + + # set 'value' attribute to a SQL expression adding one + someobject.value = some_table.c.value + 1 + + # issues "UPDATE some_table SET value=value+1" + session.commit() + +This works both for INSERT and UPDATE statements. After the flush/commit operation, the `value` attribute on `someobject` gets "deferred", so that when you again access it the newly generated value will be loaded from the database. This is the same mechanism at work when database-side column defaults fire off. + +## Using SQL Expressions with Sessions {@name=sql} + +SQL constructs and string statements can be executed via the `Session`. Normally you'd want to do this when your `Session` is transactional and you'd like your free-standing SQL statements to participate in the same transaction. + +The two ways to do this are to use the connection/execution services of the Session, or to have your Session participate in a regular SQL transaction. + +First, a Session that's associated with an Engine or Connection can execute statements immediately (whether or not it's transactional): + + {python} + Session = sessionmaker(bind=engine, transactional=True) + sess = Session() + result = sess.execute("select * from table where id=:id", {'id':7}) + result2 = sess.execute(select([mytable], mytable.c.id==7)) + +To get at the current connection used by the session, which will be part of the current transaction if one is in progress, use `connection()`: + + {python} + connection = sess.connection() + +A second scenario is that of a Session which is not directly bound to a connectable. This session executes statements relative to a particular `Mapper`, since the mappers are bound to tables which are in turn bound to connectables via their `MetaData` (either the session or the mapped tables need to be bound).
In this case, the Session can conceivably be associated with multiple databases through different mappers; so it wants you to send along a `mapper` argument, which can be any mapped class or mapper instance: + + {python} + # session is *not* bound to an engine or connection + Session = sessionmaker(transactional=True) + sess = Session() + + # need to specify mapper or class when executing + result = sess.execute("select * from table where id=:id", {'id':7}, mapper=MyMappedClass) + result2 = sess.execute(select([mytable], mytable.c.id==7), mapper=MyMappedClass) + + # need to specify mapper or class when you call connection() + connection = sess.connection(MyMappedClass) + +The third scenario is when you are using `Connection` and `Transaction` yourself, and want the `Session` to participate. This is easy, as you just bind the `Session` to the connection: + + {python} + # non-transactional session + Session = sessionmaker(transactional=False) + + # non-ORM connection + transaction + conn = engine.connect() + trans = conn.begin() + + # bind the Session *instance* to the connection + sess = Session(bind=conn) + + # ... etc + + trans.commit() + +It's safe to use a `Session` which is transactional or autoflushing, as well as to call `begin()`/`commit()` on the session too; the outermost Transaction object, the one we declared explicitly, controls the scope of the transaction. + +When using the `threadlocal` engine context, things are that much easier; the `Session` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it: + + {python} + engine = create_engine('postgres://mydb', strategy="threadlocal") + engine.begin() + + sess = Session() # session takes place in the transaction like everyone else + + # ... go nuts + + engine.commit() # commit the transaction + +## Contextual/Thread-local Sessions {@name=contextual} + +A common need in applications, particularly those built around web frameworks, is the ability to "share" a `Session` object among disparate parts of an application, without needing to pass the object explicitly to all method and function calls. What you're really looking for is some kind of "global" session object, or at least "global" to all the parts of an application which are tasked with servicing the current request. For this pattern, SQLAlchemy provides the ability to enhance the `Session` class generated by `sessionmaker()` to provide auto-contextualizing support. This means that whenever you create a `Session` instance with its constructor, you get an *existing* `Session` object which is bound to some "context". By default, this context is the current thread. This feature is what previously was accomplished using the `sessioncontext` SQLAlchemy extension. + +### Creating a Thread-local Context {@name=creating} + +The `scoped_session()` function wraps around the `sessionmaker()` function, and produces an object which behaves the same as the `Session` subclass returned by `sessionmaker()`: + + {python} + from sqlalchemy.orm import scoped_session, sessionmaker + Session = scoped_session(sessionmaker(autoflush=True, transactional=True)) + +However, when you instantiate this `Session` "class", in reality the object is pulled from a threadlocal variable, or if it doesn't exist yet, it's created using the underlying class generated by `sessionmaker()`: + + {python} + >>> # call Session() the first time. the new Session instance is created. 
+ >>> sess = Session() + + >>> # later, in the same application thread, someone else calls Session() + >>> sess2 = Session() + + >>> # the two Session objects are *the same* object + >>> sess is sess2 + True + +Since the `Session()` constructor now returns the same `Session` object every time within the current thread, the object returned by `scoped_session()` also implements most of the `Session` methods and properties at the "class" level, such that you don't even need to instantiate `Session()`: + + {python} + # create some objects + u1 = User() + u2 = User() + + # save to the contextual session, without instantiating + Session.save(u1) + Session.save(u2) + + # view the "new" attribute + assert u1 in Session.new + + # flush changes (if not using autoflush) + Session.flush() + + # commit transaction (if using a transactional session) + Session.commit() + +To "dispose" of the `Session`, there's two general approaches. One is to close out the current session, but to leave it assigned to the current context. This allows the same object to be re-used on another operation. This may be called from a current, instantiated `Session`: + + {python} + sess.close() + +Or, when using `scoped_session()`, the `close()` method may also be called as a classmethod on the `Session` "class": + + {python} + Session.close() + +When the `Session` is closed, it remains attached, but clears all of its contents and releases any ongoing transactional resources, including rolling back any remaining transactional state. The `Session` can then be used again. + +The other method is to remove the current session from the current context altogether. This is accomplished using the classmethod `remove()`: + + {python} + Session.remove() + +After `remove()` is called, the next call to `Session()` will create a *new* `Session` object which then becomes the contextual session. + +That, in a nutshell, is all there really is to it. Now for all the extra things one should know. + +### Lifespan of a Contextual Session {@name=lifespan} + +A (really, really) common question is when does the contextual session get created, when does it get disposed ? We'll consider a typical lifespan as used in a web application: + + {diagram} + Web Server Web Framework User-defined Controller Call + -------------- -------------- ------------------------------ + web request -> + call controller -> # call Session(). this establishes a new, + # contextual Session. + sess = Session() + + # load some objects, save some changes + objects = sess.query(MyClass).all() + + # some other code calls Session, it's the + # same contextual session as "sess" + sess2 = Session() + sess2.save(foo) + sess2.commit() + + # generate content to be returned + return generate_content() + Session.remove() <- + web response <- + +Above, we illustrate a *typical* organization of duties, where the "Web Framework" layer has some integration built-in to manage the span of ORM sessions. Upon the initial handling of an incoming web request, the framework passes control to a controller. The controller then calls `Session()` when it wishes to work with the ORM; this method establishes the contextual Session which will remain until it's removed. Disparate parts of the controller code may all call `Session()` and will get the same session object. Then, when the controller has completed and the response is to be sent to the web server, the framework **closes out** the current contextual session, above using the `remove()` method which removes the session from the context altogether. 
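As a rough sketch, such an integration layer might look like the following hypothetical WSGI middleware; the wrapped `app` callable and the controller code inside it are assumptions for illustration, not a particular framework's API:

    {python}
    Session = scoped_session(sessionmaker(autoflush=True, transactional=True))

    def session_middleware(app):
        # hypothetical WSGI wrapper: controller code inside 'app' calls
        # Session() freely; the contextual session is removed afterwards
        def handle(environ, start_response):
            try:
                return app(environ, start_response)
            finally:
                Session.remove()
        return handle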
+ +As an alternative, the "finalization" step can also call `Session.close()`, which will leave the same session object in place. Which one is better? For a web framework which runs from a fixed pool of threads, it doesn't matter much. For a framework which runs a **variable** number of threads, or which **creates and disposes** of a thread for each request, `remove()` is better, since it leaves no resources associated with a thread which might no longer exist. + +* Why close out the session at all? Why not just leave it going so the next request doesn't have to do as many queries? + + There are some cases where you may actually want to do this. However, this is a special case where you are dealing with data which **does not change** very often, or you don't care about the "freshness" of the data. In reality, a single thread of a web server may, on a slow day, sit around for many minutes or even hours without being accessed. When it's next accessed, if data from the previous request still exists in the session, that data may be very stale indeed. So it's generally better to have an empty session at the start of a web request. + +### Associating Classes and Mappers with a Contextual Session {@name=associating} + +Another luxury we gain, when we've established a `Session()` that can be globally accessed, is the ability for mapped classes and objects to provide us with session-oriented functionality automatically. When using the `scoped_session()` function, we access this feature using the `mapper` attribute on the object in place of the normal `sqlalchemy.orm.mapper` function: + + {python} + # "contextual" mapper function + mapper = Session.mapper + + # use normally + mapper(User, users_table, properties={ + 'addresses':relation(Address) + }) + mapper(Address, addresses_table) + +When we use the contextual `mapper()` function, our `User` and `Address` now gain a new attribute `query`, which will create a `Query` object for us against the contextual session: + + {python} + wendy = User.query.filter_by(name='wendy').one() + +#### Auto-Save Behavior with Contextual Session's Mapper {@name=autosave} + +By default, when using Session.mapper, **new instances are saved into the contextual session automatically upon construction;** there is no longer a need to call `save()`: + + {python} + >>> newuser = User(name='ed') + >>> newuser in Session.new + True + +The auto-save functionality can cause problems, namely that any `flush()` which occurs before a newly constructed object is fully populated will result in that object being INSERTed without all of its attributes completed. As `flush()` calls are more frequent when using sessions with `autoflush=True`, **the auto-save behavior can be disabled**, using the `save_on_init=False` flag: + + {python} + # "contextual" mapper function + mapper = Session.mapper + + # use normally, specify no save on init: + mapper(User, users_table, properties={ + 'addresses':relation(Address) + }, save_on_init=False) + mapper(Address, addresses_table, save_on_init=False) + + # objects now again require explicit "save" + >>> newuser = User(name='ed') + >>> newuser in Session.new + False + + >>> Session.save(newuser) + >>> newuser in Session.new + True + +The functionality of `Session.mapper` is an updated version of what used to be accomplished by the `assignmapper()` SQLAlchemy extension.
+ +[Generated docstrings for scoped_session()](rel:docstrings_sqlalchemy.orm_modfunc_scoped_session) + +## Partitioning Strategies + +this section is TODO + +### Vertical Partitioning + +Vertical partitioning places different kinds of objects, or different tables, across multiple databases. + + {python} + engine1 = create_engine('postgres://db1') + engine2 = create_engine('postgres://db2') + + Session = sessionmaker(twophase=True, transactional=True) + + # bind User operations to engine 1, Account operations to engine 2 + Session.configure(binds={User:engine1, Account:engine2}) + + sess = Session() + +### Horizontal Partitioning + +Horizontal partitioning partitions the rows of a single table (or a set of tables) across multiple databases. + +See the "sharding" example in [attribute_shard.py](http://www.sqlalchemy.org/trac/browser/sqlalchemy/trunk/examples/sharding/attribute_shard.py) + +## Extending Session + +Extending the session can be achieved through subclassing as well as through a simple extension class, which resembles the style of [advdatamapping_mapper_extending](rel:advdatamapping_mapper_extending) called [SessionExtension](rel:docstrings_sqlalchemy.orm.session_SessionExtension). See the docstrings for more information on this class' methods. + +Basic usage is similar to `MapperExtension`: + + {python} + class MySessionExtension(SessionExtension): + def before_commit(self, session): + print "before commit!" + + Session = sessionmaker(extension=MySessionExtension()) + +or with `create_session()`: + + {python} + sess = create_session(extension=MySessionExtension()) + +The same `SessionExtension` instance can be used with any number of sessions. diff --git a/doc/build/content/sqlconstruction.txt b/doc/build/content/sqlconstruction.txt deleted file mode 100644 index a672fb5cec..0000000000 --- a/doc/build/content/sqlconstruction.txt +++ /dev/null @@ -1,963 +0,0 @@ -Constructing SQL Queries via Python Expressions {@name=sql} -=============================================== - -*Note:* This section describes how to use SQLAlchemy to construct SQL queries and receive result sets. It does *not* cover the object relational mapping capabilities of SQLAlchemy; that is covered later on in [datamapping](rel:datamapping). However, both areas of functionality work similarly in how selection criterion is constructed, so if you are interested just in ORM, you should probably skim through basic [sql_whereclause](rel:sql_whereclause) construction before moving on. - -Once you have used the `sqlalchemy.schema` module to construct your tables and/or reflect them from the database, performing SQL queries using those table meta data objects is done via the `sqlalchemy.sql` package. This package defines a large set of classes, each of which represents a particular kind of lexical construct within a SQL query; all are descendants of the common base class `sqlalchemy.sql.ClauseElement`. A full query is represented via a structure of `ClauseElement`s. A set of reasonably intuitive creation functions is provided by the `sqlalchemy.sql` package to create these structures; these functions are described in the rest of this section. - -Executing a `ClauseElement` structure can be performed in two general ways. You can use an `Engine` or a `Connection` object's `execute()` method to which you pass the query structure; this is known as **explicit style**. 
Or, if the `ClauseElement` structure is built upon Table metadata which is bound to an `Engine` directly, you can simply call `execute()` on the structure itself, known as **implicit style**. In both cases, the execution returns a cursor-like object (more on that later). The same clause structure can be executed repeatedly. The `ClauseElement` is compiled into a string representation by an underlying `Compiler` object which is associated with the `Engine` via its `Dialect`. - -The examples below all include a dump of the generated SQL corresponding to the query object, as well as a dump of the statement's bind parameters. In all cases, bind parameters are shown as named parameters using the colon format (i.e. ':name'). When the statement is compiled into a database-specific version, the named-parameter statement and its bind values are converted to the proper paramstyle for that database automatically. - -For this section, we will mostly use the implcit style of execution, meaning the `Table` objects are associated with a bound instance of `MetaData`, and constructed `ClauseElement` objects support self-execution. Assume the following configuration: - - {python} - from sqlalchemy import * - metadata = MetaData('sqlite:///mydb.db', echo=True) - - # a table to store users - users = Table('users', metadata, - Column('user_id', Integer, primary_key = True), - Column('user_name', String(40)), - Column('password', String(80)) - ) - - # a table that stores mailing addresses associated with a specific user - addresses = Table('addresses', metadata, - Column('address_id', Integer, primary_key = True), - Column('user_id', Integer, ForeignKey("users.user_id")), - Column('street', String(100)), - Column('city', String(80)), - Column('state', String(2)), - Column('zip', String(10)) - ) - - # a table that stores keywords - keywords = Table('keywords', metadata, - Column('keyword_id', Integer, primary_key = True), - Column('name', VARCHAR(50)) - ) - - # a table that associates keywords with users - userkeywords = Table('userkeywords', metadata, - Column('user_id', INT, ForeignKey("users")), - Column('keyword_id', INT, ForeignKey("keywords")) - ) - -### Simple Select {@name=select} - -A select is done by constructing a `Select` object with the proper arguments [[api](rel:docstrings_sqlalchemy.sql_modfunc_select)], adding any extra arguments if desired, then calling its `execute()` method. 
- - {python title="Basic Select"} - from sqlalchemy import * - - # use the select() function defined in the sql package - s = select([users]) - - # or, call the select() method off of a Table object - s = users.select() - - # then, call execute on the Select object: - {sql}result = s.execute() - SELECT users.user_id, users.user_name, users.password FROM users - {} - - # the SQL text of any clause object can also be viewed via the str() call: - >>> str(s) - SELECT users.user_id, users.user_name, users.password FROM users - -#### Explicit Execution {@name=explicit} - -As mentioned above, `ClauseElement` structures can also be executed with a `Connection` object explicitly: - - {python} - engine = create_engine('sqlite:///myfile.db') - conn = engine.connect() - - {sql}result = conn.execute(users.select()) - SELECT users.user_id, users.user_name, users.password FROM users - {} - - conn.close() - -#### Binding ClauseElements to Engines {@name=binding} - -For queries that don't contain any "bound" tables, `ClauseElement`s that represent a fully executeable statement support an `bind` keyword parameter which can bind the object to an `Engine` or `Connection`, thereby allowing implicit execution: - - {python} - # select using a table - {sql}select([users], bind=myengine).execute() - SELECT users.user_id, users.user_name, users.password FROM users - {} - - # select a literal - {sql}select(["current_time"], bind=myengine).execute() - SELECT current_time - {} - - # select a function - {sql}select([func.now()], bind=db).execute() - SELECT now() - {} - -#### Getting Results {@name=resultproxy} - -The object returned by `execute()` is a `sqlalchemy.engine.ResultProxy` object, which acts much like a DBAPI `cursor` object in the context of a result set, except that the rows returned can address their columns by ordinal position, column name, or even column object: - - {python title="Using the ResultProxy"} - # select rows, get resulting ResultProxy object - {sql}result = users.select().execute() - SELECT users.user_id, users.user_name, users.password FROM users - {} - - # get one row - row = result.fetchone() - - # get the 'user_id' column via integer index: - user_id = row[0] - - # or column name - user_name = row['user_name'] - - # or column object - password = row[users.c.password] - - # or column accessor - password = row.password - - # ResultProxy object also supports fetchall() - rows = result.fetchall() - - # or get the underlying DBAPI cursor object - cursor = result.cursor - - # after an INSERT, return the last inserted primary key value - # returned as a list of primary key values for *one* row - # (a list since primary keys can be composite) - id = result.last_inserted_ids() - - # close the result. If the statement was implicitly executed - # (i.e. without an explicit Connection), this will - # return the underlying connection resources back to - # the connection pool. de-referencing the result - # will also have the same effect. if an explicit Connection was - # used, then close() just closes the underlying cursor object. - result.close() - -#### Using Column Labels {@name=labels} - -A common need when writing statements that reference multiple tables is to create labels for columns, thereby separating columns from different tables with the same name. 
The Select construct supports automatic generation of column labels via the `use_labels=True` parameter: - - {python title="use_labels Flag"} - {sql}c = select([users, addresses], - users.c.user_id==addresses.c.address_id, - use_labels=True).execute() - SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.password AS users_password, addresses.address_id AS addresses_address_id, - addresses.user_id AS addresses_user_id, addresses.street AS addresses_street, - addresses.city AS addresses_city, addresses.state AS addresses_state, - addresses.zip AS addresses_zip - FROM users, addresses - WHERE users.user_id = addresses.address_id - {} - -The table name part of the label is affected if you use a construct such as a table alias: - - {python title="use_labels with an Alias"} - person = users.alias('person') - {sql}c = select([person, addresses], - person.c.user_id==addresses.c.address_id, - use_labels=True).execute() - SELECT person.user_id AS person_user_id, person.user_name AS person_user_name, - person.password AS person_password, addresses.address_id AS addresses_address_id, - addresses.user_id AS addresses_user_id, addresses.street AS addresses_street, - addresses.city AS addresses_city, addresses.state AS addresses_state, - addresses.zip AS addresses_zip FROM users AS person, addresses - WHERE person.user_id = addresses.address_id - -Labels are also generated in such a way as to never go beyond 30 characters. Most databases support a limit on the length of symbols, such as Postgres, and particularly Oracle which has a rather short limit of 30: - - {python title="use_labels Generates Abbreviated Labels"} - long_named_table = users.alias('this_is_the_person_table') - {sql}c = select([long_named_table], use_labels=True).execute() - SELECT this_is_the_person_table.user_id AS this_is_the_person_table_b36c, - this_is_the_person_table.user_name AS this_is_the_person_table_f76a, - this_is_the_person_table.password AS this_is_the_person_table_1e7c - FROM users AS this_is_the_person_table - {} - -You can also specify custom labels on a per-column basis using the `label()` function: - - {python title="label() Function on Column"} - {sql}c = select([users.c.user_id.label('id'), - users.c.user_name.label('name')]).execute() - SELECT users.user_id AS id, users.user_name AS name - FROM users - {} - -#### Table/Column Specification {@name=columns} - -Calling `select` off a table automatically generates a column clause which includes all the table's columns, in the order they are specified in the source Table object. 
- -But in addition to selecting all the columns off a single table, any set of columns can be specified, as well as full tables, and any combination of the two: - - {python title="Specify Columns to Select"} - # individual columns - {sql}c = select([users.c.user_id, users.c.user_name]).execute() - SELECT users.user_id, users.user_name FROM users - {} - - # full tables - {sql}c = select([users, addresses]).execute() - SELECT users.user_id, users.user_name, users.password, - addresses.address_id, addresses.user_id, - addresses.street, addresses.city, addresses.state, addresses.zip - FROM users, addresses - {} - - # combinations - {sql}c = select([users, addresses.c.zip]).execute() - SELECT users.user_id, users.user_name, users.password, - addresses.zip FROM users, addresses - {} - -### WHERE Clause {@name=whereclause} - -The WHERE condition is the named keyword argument `whereclause`, or the second positional argument to the `select()` constructor and the first positional argument to the `select()` method of `Table`. - -WHERE conditions are constructed using column objects, literal values, and functions defined in the `sqlalchemy.sql` module. Column objects override the standard Python operators to provide clause compositional objects, which compile down to SQL operations: - - {python title="Basic WHERE Clause"} - {sql}c = users.select(users.c.user_id == 7).execute() - SELECT users.user_id, users.user_name, users.password, - FROM users WHERE users.user_id = :users_user_id - {'users_user_id': 7} - -Notice that the literal value "7" was broken out of the query and placed into a bind parameter. Databases such as Oracle must parse incoming SQL and create a "plan" when new queries are received, which is an expensive process. By using bind parameters, the same query with various literal values can have its plan compiled only once, and used repeatedly with less overhead. 
- -More where clauses: - - {python} - # another comparison operator - {sql}c = select([users], users.c.user_id>7).execute() - SELECT users.user_id, users.user_name, users.password, - FROM users WHERE users.user_id > :users_user_id - {'users_user_id': 7} - - # OR keyword - {sql}c = users.select(or_(users.c.user_name=='jack', users.c.user_name=='ed')).execute() - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_name = :users_user_name - OR users.user_name = :users_user_name_1 - {'users_user_name_1': 'ed', 'users_user_name': 'jack'} - - # AND keyword - {sql}c = users.select(and_(users.c.user_name=='jack', users.c.password=='dog')).execute() - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_name = :users_user_name - AND users.password = :users_password - {'users_user_name': 'jack', 'users_password': 'dog'} - - # NOT keyword - {sql}c = users.select(not_( - or_(users.c.user_name=='jack', users.c.password=='dog') - )).execute() - SELECT users.user_id, users.user_name, users.password - FROM users - WHERE NOT (users.user_name = :users_user_name - OR users.password = :users_password) - {'users_user_name': 'jack', 'users_password': 'dog'} - - # IN clause - {sql}c = users.select(users.c.user_name.in_('jack', 'ed', 'fred')).execute() - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_name - IN (:users_user_name, :users_user_name_1, :users_user_name_2) - {'users_user_name': 'jack', 'users_user_name_1': 'ed', - 'users_user_name_2': 'fred'} - - - # join users and addresses together - {sql}c = select([users, addresses], users.c.user_id==addresses.c.address_id).execute() - SELECT users.user_id, users.user_name, users.password, - addresses.address_id, addresses.user_id, addresses.street, addresses.city, - addresses.state, addresses.zip - FROM users, addresses - WHERE users.user_id = addresses.address_id - {} - - - # join users and addresses together, but dont specify "addresses" in the - # selection criterion. The WHERE criterion adds it to the FROM list - # automatically. - {sql}c = select([users], and_( - users.c.user_id==addresses.c.user_id, - users.c.user_name=='fred' - )).execute() - SELECT users.user_id, users.user_name, users.password - FROM users, addresses WHERE users.user_id = addresses.user_id - AND users.user_name = :users_user_name - {'users_user_name': 'fred'} - - -Select statements can also generate a WHERE clause based on the parameters you give it. If a given parameter, which matches the name of a column or its "label" (the combined tablename + "_" + column name), and does not already correspond to a bind parameter in the select object, it will be added as a comparison against that column. This is a shortcut to creating a full WHERE clause: - - {python} - # specify a match for the "user_name" column - {sql}c = users.select().execute(user_name='ed') - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_name = :users_user_name - {'users_user_name': 'ed'} - - # specify a full where clause for the "user_name" column, as well as a - # comparison for the "user_id" column - {sql}c = users.select(users.c.user_name=='ed').execute(user_id=10) - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_name = :users_user_name AND users.user_id = :users_user_id - {'users_user_name': 'ed', 'users_user_id': 10} - -#### Operators {@name=operators} - -Supported column operators so far are all the numerical comparison operators, i.e. 
'==', '>', '>=', etc., as well as `like()`, `startswith()`, `endswith()`, `between()`, and `in()`. Boolean operators include `not_()`, `and_()` and `or_()`, which also can be used inline via '~', '&', and '|'. Math operators are '+', '-', '*', '/'. Any custom operator can be specified via the `op()` function shown below. - - {python} - # "like" operator - users.select(users.c.user_name.like('%ter')) - - # equality operator - users.select(users.c.user_name == 'jane') - - # in opertator - users.select(users.c.user_id.in_(1,2,3)) - - # and_, endswith, equality operators - users.select(and_(addresses.c.street.endswith('green street'), - addresses.c.zip=='11234')) - - # & operator subsituting for 'and_' - users.select(addresses.c.street.endswith('green street') & (addresses.c.zip=='11234')) - - # + concatenation operator - select([users.c.user_name + '_name']) - - # NOT operator - users.select(~(addresses.c.street == 'Green Street')) - - # any custom operator - select([users.c.user_name.op('||')('_category')]) - - # "null" comparison via == (converts to IS) - {sql}users.select(users.c.user_name==None).execute() - SELECT users.user_id, users.user_name, users.password - FROM users - WHERE users.user_name IS NULL - - # or via explicit null() construct - {sql}users.select(users.c.user_name==null()).execute() - SELECT users.user_id, users.user_name, users.password - FROM users - WHERE users.user_name IS NULL - -#### Functions {@name=functions} - -Functions can be specified using the `func` keyword: - - {python} - {sql}select([func.count(users.c.user_id)]).execute() - SELECT count(users.user_id) FROM users - - {sql}users.select(func.substr(users.c.user_name, 1) == 'J').execute() - SELECT users.user_id, users.user_name, users.password FROM users - WHERE substr(users.user_name, :substr) = :substr_1 - {'substr_1': 'J', 'substr': 1} - -Functions also are callable as standalone values: - - {python} - # call the "now()" function - time = func.now(bind=myengine).scalar() - - # call myfunc(1,2,3) - myvalue = func.myfunc(1, 2, 3, bind=db).execute() - - # or call them off the engine - db.func.now().scalar() - -#### Literals {@name=literals} - -You can drop in a literal value anywhere there isnt a column to attach to via the `literal` keyword: - - {python} - {sql}select([literal('foo') + literal('bar'), users.c.user_name]).execute() - SELECT :literal + :literal_1, users.user_name - FROM users - {'literal_1': 'bar', 'literal': 'foo'} - - # literals have all the same comparison functions as columns - {sql}select([literal('foo') == literal('bar')], bind=myengine).scalar() - SELECT :literal = :literal_1 - {'literal_1': 'bar', 'literal': 'foo'} - -Literals also take an optional `type` parameter to give literals a type. This can sometimes be significant, for example when using the "+" operator with SQLite, the String type is detected and the operator is converted to "||": - - {python} - {sql}select([literal('foo', type=String) + 'bar'], bind=e).execute() - SELECT ? || ? 
- ['foo', 'bar'] - -#### Order By {@name=orderby} - -The ORDER BY clause of a select statement can be specified as individual columns to order by within an array specified via the `order_by` parameter, and optional usage of the asc() and desc() functions: - - {python} - # straight order by - {sql}c = users.select(order_by=[users.c.user_name]).execute() - SELECT users.user_id, users.user_name, users.password - FROM users ORDER BY users.user_name - - # descending/ascending order by on multiple columns - {sql}c = users.select( - users.c.user_name>'J', - order_by=[desc(users.c.user_id), asc(users.c.user_name)]).execute() - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_name > :users_user_name - ORDER BY users.user_id DESC, users.user_name ASC - {'users_user_name':'J'} - -#### DISTINCT, LIMIT and OFFSET {@name=options} - -These are specified as keyword arguments: - - {python} - {sql}c = select([users.c.user_name], distinct=True).execute() - SELECT DISTINCT users.user_name FROM users - - {sql}c = users.select(limit=10, offset=20).execute() - SELECT users.user_id, users.user_name, users.password FROM users LIMIT 10 OFFSET 20 - -The Oracle driver does not support LIMIT and OFFSET directly, but instead wraps the generated query into a subquery and uses the "rownum" variable to control the rows selected (this is somewhat experimental). Similarly, the Firebird and MSSQL drivers convert LIMIT into queries using FIRST and TOP, respectively. - -### Inner and Outer Joins {@name=join} - -As some of the examples indicated above, a regular inner join can be implicitly stated, just like in a SQL expression, by just specifying the tables to be joined as well as their join conditions: - - {python} - {sql}addresses.select(addresses.c.user_id==users.c.user_id).execute() - SELECT addresses.address_id, addresses.user_id, addresses.street, - addresses.city, addresses.state, addresses.zip FROM addresses, users - WHERE addresses.user_id = users.user_id - {} - -There is also an explicit join constructor, which can be embedded into a select query via the `from_obj` parameter of the select statement: - - {python} - {sql}addresses.select(from_obj=[ - addresses.join(users, addresses.c.user_id==users.c.user_id) - ]).execute() - SELECT addresses.address_id, addresses.user_id, addresses.street, addresses.city, - addresses.state, addresses.zip - FROM addresses JOIN users ON addresses.user_id = users.user_id - {} - -The join constructor can also be used by itself: - - {python} - {sql}join(users, addresses, users.c.user_id==addresses.c.user_id).select().execute() - SELECT users.user_id, users.user_name, users.password, - addresses.address_id, addresses.user_id, addresses.street, addresses.city, - addresses.state, addresses.zip - FROM addresses JOIN users ON addresses.user_id = users.user_id - {} - -The join criterion in a join() call is optional. If not specified, the condition will be derived from the foreign key relationships of the two tables. If no criterion can be constructed, an exception will be raised. - - {python} - {sql}join(users, addresses).select().execute() - SELECT users.user_id, users.user_name, users.password, - addresses.address_id, addresses.user_id, addresses.street, addresses.city, - addresses.state, addresses.zip - FROM addresses JOIN users ON addresses.user_id = users.user_id - {} - -Notice that this is the first example where the FROM criterion of the select statement is explicitly specified. 
In most cases, the FROM criterion is automatically determined from the columns requested as well as the WHERE clause. The `from_obj` keyword argument indicates a list of explicit FROM clauses to be used in the statement. - -A join can be created on its own using the `join` or `outerjoin` functions, or can be created off of an existing Table or other selectable unit via the `join` or `outerjoin` methods: - - {python} - {sql}outerjoin(users, addresses, - users.c.user_id==addresses.c.address_id).select().execute() - SELECT users.user_id, users.user_name, users.password, addresses.address_id, - addresses.user_id, addresses.street, addresses.city, addresses.state, addresses.zip - FROM users LEFT OUTER JOIN addresses ON users.user_id = addresses.address_id - {} - - {sql}users.select(keywords.c.name=='running', from_obj=[ - users.join( - userkeywords, userkeywords.c.user_id==users.c.user_id).join( - keywords, keywords.c.keyword_id==userkeywords.c.keyword_id) - ]).execute() - SELECT users.user_id, users.user_name, users.password FROM users - JOIN userkeywords ON userkeywords.user_id = users.user_id - JOIN keywords ON keywords.keyword_id = userkeywords.keyword_id - WHERE keywords.name = :keywords_name - {'keywords_name': 'running'} - -Joins also provide a keyword argument `fold_equivalents` on the `select()` function which allows the column list of the resulting select to be "folded" to the minimal list of columns, based on those columns that are known to be equivalent from the "onclause" of the join. This saves the effort of constructing column lists manually in conjunction with databases like Postgres which can be picky about "ambiguous columns". In this example, only the "users.user_id" column, but not the "addresses.user_id" column, shows up in the column clause of the resulting select: - - {python} - {sql}users.join(addresses).select(fold_equivalents=True).execute() - SELECT users.user_id, users.user_name, users.password, addresses.address_id, - addresses.street, addresses.city, addresses.state, addresses.zip - FROM users JOIN addresses ON users.user_id = addresses.address_id - {} - -The `fold_equivalents` argument will recursively apply to "chained" joins as well, i.e. `a.join(b).join(c)...`. - -### Table Aliases {@name=alias} - -Aliases are used primarily when you want to use the same table more than once as a FROM expression in a statement: - - {python} - address_b = addresses.alias('addressb') - {sql}# select users who have an address on Green street as well as Orange street - users.select(and_( - users.c.user_id==addresses.c.user_id, - addresses.c.street.like('%Green%'), - users.c.user_id==address_b.c.user_id, - address_b.c.street.like('%Orange%') - )).execute() - SELECT users.user_id, users.user_name, users.password - FROM users, addresses, addresses AS addressb - WHERE users.user_id = addresses.user_id - AND addresses.street LIKE :addresses_street - AND users.user_id = addressb.user_id - AND addressb.street LIKE :addressb_street - {'addressb_street': '%Orange%', 'addresses_street': '%Green%'} - -### Subqueries {@name=subqueries} - -SQLAlchemy allows the creation of select statements from not just Table objects, but from a whole class of objects that implement the `Selectable` interface. This includes Tables, Aliases, Joins and Selects. 
Therefore, if you have a Select, you can select from the Select: - - {python} - >>> s = users.select() - >>> str(s) - SELECT users.user_id, users.user_name, users.password FROM users - - {python} - >>> s = s.select() - >>> str(s) - SELECT user_id, user_name, password - FROM (SELECT users.user_id, users.user_name, users.password FROM users) - -Any Select, Join, or Alias object supports the same column accessors as a Table: - - {python} - >>> s = users.select() - >>> [c.key for c in s.columns] - ['user_id', 'user_name', 'password'] - -When you use `use_labels=True` in a Select object, the label version of the column names become the keys of the accessible columns. In effect you can create your own "view objects": - - {python} - s = select([users, addresses], users.c.user_id==addresses.c.user_id, use_labels=True) - {sql}select([ - s.c.users_user_name, s.c.addresses_street, s.c.addresses_zip - ], s.c.addresses_city=='San Francisco').execute() - SELECT users_user_name, addresses_street, addresses_zip - FROM (SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, - users.password AS users_password, addresses.address_id AS addresses_address_id, - addresses.user_id AS addresses_user_id, addresses.street AS addresses_street, - addresses.city AS addresses_city, addresses.state AS addresses_state, - addresses.zip AS addresses_zip - FROM users, addresses - WHERE users.user_id = addresses.user_id) - WHERE addresses_city = :addresses_city - {'addresses_city': 'San Francisco'} - -To specify a SELECT statement as one of the selectable units in a FROM clause, it usually should be given an alias. - - {python} - {sql}s = users.select().alias('u') - select([addresses, s]).execute() - SELECT addresses.address_id, addresses.user_id, addresses.street, addresses.city, - addresses.state, addresses.zip, u.user_id, u.user_name, u.password - FROM addresses, - (SELECT users.user_id, users.user_name, users.password FROM users) AS u - {} - -Select objects can be used in a WHERE condition, in operators such as IN: - - {python} - # select user ids for all users whos name starts with a "p" - s = select([users.c.user_id], users.c.user_name.like('p%')) - - # now select all addresses for those users - {sql}addresses.select(addresses.c.user_id.in_(s)).execute() - SELECT addresses.address_id, addresses.user_id, addresses.street, - addresses.city, addresses.state, addresses.zip - FROM addresses WHERE addresses.address_id IN - (SELECT users.user_id FROM users WHERE users.user_name LIKE :users_user_name) - {'users_user_name': 'p%'} - -The sql package supports embedding select statements into other select statements as the criterion in a WHERE condition, or as one of the "selectable" objects in the FROM list of the query. It does not at the moment directly support embedding a SELECT statement as one of the column criterion for a statement, although this can be achieved via direct text insertion, described later. 
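As a quick sketch of that textual approach (using the string-based column placement covered later under "Literal Text Blocks", against the users and addresses tables from this chapter):

    {python}
    # embed a correlated subquery as literal text in the column list
    select([users.c.user_name,
        "(SELECT count(*) FROM addresses WHERE addresses.user_id = users.user_id) AS address_count"])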
- -#### Scalar Column Subqueries {@name=scalar} - -Subqueries can be used in the column clause of a select statement by specifying the `scalar=True` flag: - - {python} - {sql}select([table2.c.col1, table2.c.col2, - select([table1.c.col1], table1.c.col2==7, scalar=True)]) - SELECT table2.col1, table2.col2, - (SELECT table1.col1 AS col1 FROM table1 WHERE col2=:table1_col2) - FROM table2 - {'table1_col2': 7} - -### Correlated Subqueries {@name=correlated} - -When a select object is embedded inside of another select object, and both objects reference the same table, SQLAlchemy makes the assumption that the table should be correlated from the child query to the parent query. To disable this behavior, specify the flag `correlate=False` to the Select statement. - - {python} - # make an alias of a regular select. - s = select([addresses.c.street], addresses.c.user_id==users.c.user_id).alias('s') - >>> str(s) - SELECT addresses.street FROM addresses, users - WHERE addresses.user_id = users.user_id - - # now embed that select into another one. the "users" table is removed from - # the embedded query's FROM list and is instead correlated to the parent query - s2 = select([users, s.c.street]) - >>> str(s2) - SELECT users.user_id, users.user_name, users.password, s.street - FROM users, (SELECT addresses.street FROM addresses - WHERE addresses.user_id = users.user_id) s - -#### EXISTS Clauses {@name=exists} - -An EXISTS clause can function as a higher-scaling version of an IN clause, and is usually used in a correlated fashion: - - {python} - # find all users who have an address on Green street: - {sql}users.select( - exists( - [addresses.c.address_id], - and_( - addresses.c.user_id==users.c.user_id, - addresses.c.street.like('%Green%') - ) - ) - ) - SELECT users.user_id, users.user_name, users.password - FROM users WHERE EXISTS (SELECT addresses.address_id - FROM addresses WHERE addresses.user_id = users.user_id - AND addresses.street LIKE :addresses_street) - {'addresses_street': '%Green%'} - -### Unions {@name=unions} - -Unions come in two flavors, UNION and UNION ALL, which are available via module level functions or methods off a Selectable: - - {python} - {sql}union( - addresses.select(addresses.c.street=='123 Green Street'), - addresses.select(addresses.c.street=='44 Park Ave.'), - addresses.select(addresses.c.street=='3 Mill Road'), - order_by=[addresses.c.street] - ).execute() - SELECT addresses.address_id, addresses.user_id, addresses.street, - addresses.city, addresses.state, addresses.zip - FROM addresses WHERE addresses.street = :addresses_street - UNION - SELECT addresses.address_id, addresses.user_id, addresses.street, - addresses.city, addresses.state, addresses.zip - FROM addresses WHERE addresses.street = :addresses_street_1 - UNION - SELECT addresses.address_id, addresses.user_id, addresses.street, - addresses.city, addresses.state, addresses.zip - FROM addresses WHERE addresses.street = :addresses_street_2 - ORDER BY addresses.street - {'addresses_street_1': '44 Park Ave.', - 'addresses_street': '123 Green Street', - 'addresses_street_2': '3 Mill Road'} - - {sql}users.select( - users.c.user_id==7 - ).union_all( - users.select( - users.c.user_id==9 - ), - order_by=[users.c.user_id] # order_by is an argument to union_all() - ).execute() - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_id = :users_user_id - UNION ALL - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_id = :users_user_id_1 - ORDER BY users.user_id - 
{'users_user_id_1': 9, 'users_user_id': 7} - -### Custom Bind Parameters {@name=bindparams} - -Throughout all these examples, SQLAlchemy is busy creating bind parameters wherever literal expressions occur. You can also specify your own bind parameters with your own names, and use the same statement repeatedly. The bind parameters, shown here in the "named" format, will be converted to the appropriate named or positional style according to the database implementation being used. - - {python title="Custom Bind Params"} - s = users.select(users.c.user_name==bindparam('username')) - - # execute implicitly - {sql}s.execute(username='fred') - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_name = :username - {'username': 'fred'} - - # execute explicitly - conn = engine.connect() - {sql}conn.execute(s, username='fred') - SELECT users.user_id, users.user_name, users.password - FROM users WHERE users.user_name = :username - {'username': 'fred'} - - -`executemany()` is also available by supplying multiple dictionary arguments instead of keyword arguments to the `execute()` method of `ClauseElement` or `Connection`. Examples can be found later in the sections on INSERT/UPDATE/DELETE. - -#### Precompiling a Query {@name=precompiling} - -By throwing the `compile()` method onto the end of any query object, the query can be "compiled" by the Engine into a `sqlalchemy.sql.Compiled` object just once, and the resulting compiled object reused, which eliminates repeated internal compilation of the SQL string: - - {python}s = users.select(users.c.user_name==bindparam('username')).compile() - s.execute(username='fred') - s.execute(username='jane') - s.execute(username='mary') - -### Literal Text Blocks {@name=textual} - -The sql package tries to allow free textual placement in as many ways as possible. In the examples below, note that the from_obj parameter is used only when no other information exists within the select object with which to determine table metadata. Also note that in a query where there isnt even table metadata used, the Engine to be used for the query has to be explicitly specified: - - {python} - # strings as column clauses - {sql}select(["user_id", "user_name"], from_obj=[users]).execute() - SELECT user_id, user_name FROM users - {} - - # strings for full column lists - {sql}select( - ["user_id, user_name, password, addresses.*"], - from_obj=[users.alias('u'), addresses]).execute() - SELECT u.user_id, u.user_name, u.password, addresses.* - FROM users AS u, addresses - {} - - # functions, etc. 
- {sql}select([users.c.user_id, "process_string(user_name)"]).execute() - SELECT users.user_id, process_string(user_name) FROM users - {} - - # where clauses - {sql}users.select(and_(users.c.user_id==7, "process_string(user_name)=27")).execute() - SELECT users.user_id, users.user_name, users.password FROM users - WHERE users.user_id = :users_user_id AND process_string(user_name)=27 - {'users_user_id': 7} - - # subqueries - {sql}users.select( - "exists (select 1 from addresses where addresses.user_id=users.user_id)").execute() - SELECT users.user_id, users.user_name, users.password FROM users - WHERE exists (select 1 from addresses where addresses.user_id=users.user_id) - {} - - # custom FROM objects - {sql}select( - ["*"], - from_obj=["(select user_id, user_name from users)"], - bind=db).execute() - SELECT * FROM (select user_id, user_name from users) - {} - - # a full query - {sql}text("select user_name from users", bind=db).execute() - SELECT user_name FROM users - {} - - -#### Using Bind Parameters in Text Blocks {@name=textual_binds} - -Use the format `':paramname'` to define bind parameters inside of a text block. They will be converted to the appropriate format upon compilation: - - {python}t = text("select foo from mytable where lala=:hoho", bind=engine) - r = t.execute(hoho=7) - -Bind parameters can also be explicit, which allows typing information to be added. Just specify them as a list with keys that match those inside the textual statement: - - {python}t = text("select foo from mytable where lala=:hoho", - bindparams=[bindparam('hoho', type=types.String)], bind=engine) - r = t.execute(hoho="im hoho") - -Result-row type processing can be added via the `typemap` argument, which is a dictionary of return columns mapped to types: - - {python}# specify DateTime type for the 'foo' column in the result set - # sqlite, for example, uses result-row post-processing to construct dates - t = text("select foo from mytable where lala=:hoho", - bindparams=[bindparam('hoho', type=types.String)], - typemap={'foo':types.DateTime}, bind=engine - ) - r = t.execute(hoho="im hoho") - - # 'foo' is a datetime - year = r.fetchone()['foo'].year - -### Building Select Objects {@name=building} - -One of the primary motivations for a programmatic SQL library is to allow the piecemeal construction of a SQL statement based on program variables. All the above examples typically show Select objects being created all at once. The Select object also includes "builder" methods to allow building up an object. 
The below example is a "user search" function, where users can be selected based on primary key, user name, street address, keywords, or any combination: - - {python} - def find_users(id=None, name=None, street=None, keywords=None): - statement = users.select() - if id is not None: - statement.append_whereclause(users.c.user_id==id) - if name is not None: - statement.append_whereclause(users.c.user_name==name) - if street is not None: - # append_whereclause joins "WHERE" conditions together with AND - statement.append_whereclause(users.c.user_id==addresses.c.user_id) - statement.append_whereclause(addresses.c.street==street) - if keywords is not None: - statement.append_from( - users.join(userkeywords, users.c.user_id==userkeywords.c.user_id).join( - keywords, userkeywords.c.keyword_id==keywords.c.keyword_id)) - statement.append_whereclause(keywords.c.name.in_(keywords)) - # to avoid multiple repeats, set query to be DISTINCT: - statement.distinct=True - return statement.execute() - - {sql}find_users(id=7) - SELECT users.user_id, users.user_name, users.password - FROM users - WHERE users.user_id = :users_user_id - {'users_user_id': 7} - - {sql}find_users(street='123 Green Street') - SELECT users.user_id, users.user_name, users.password - FROM users, addresses - WHERE users.user_id = addresses.user_id AND addresses.street = :addresses_street - {'addresses_street': '123 Green Street'} - - {sql}find_users(name='Jack', keywords=['jack','foo']) - SELECT DISTINCT users.user_id, users.user_name, users.password - FROM users JOIN userkeywords ON users.user_id = userkeywords.user_id - JOIN keywords ON userkeywords.keyword_id = keywords.keyword_id - WHERE users.user_name = :users_user_name AND keywords.name IN ('jack', 'foo') - {'users_user_name': 'Jack'} - -### Inserts {@name=insert} - -An INSERT involves just one table. The Insert object is used via the insert() function, and the specified columns determine what columns show up in the generated SQL. If primary key columns are left out of the criterion, the SQL generator will try to populate them as specified by the particular database engine and sequences, i.e. relying upon an auto-incremented column or explicitly calling a sequence beforehand. Insert statements, as well as updates and deletes, can also execute multiple parameters in one pass via specifying an array of dictionaries as parameters. - -The values to be populated for an INSERT or an UPDATE can be specified to the insert()/update() functions as the `values` named argument, or the query will be compiled based on the values of the parameters sent to the execute() method. 
- - {python title="Using insert()"} - # basic insert - {sql}users.insert().execute(user_id=1, user_name='jack', password='asdfdaf') - INSERT INTO users (user_id, user_name, password) - VALUES (:user_id, :user_name, :password) - {'user_name': 'jack', 'password': 'asdfdaf', 'user_id': 1} - - # insert just user_name, NULL for others - # will auto-populate primary key columns if they are configured - # to do so - {sql}users.insert().execute(user_name='ed') - INSERT INTO users (user_name) VALUES (:user_name) - {'user_name': 'ed'} - - # INSERT with a list: - {sql}users.insert(values=(3, 'jane', 'sdfadfas')).execute() - INSERT INTO users (user_id, user_name, password) - VALUES (:user_id, :user_name, :password) - {'user_id': 3, 'password': 'sdfadfas', 'user_name': 'jane'} - - # INSERT with user-defined bind parameters - i = users.insert( - values={'user_name':bindparam('name'), 'password':bindparam('pw')} - ) - {sql}i.execute(name='mary', pw='adas5fs') - INSERT INTO users (user_name, password) VALUES (:name, :pw) - {'name': 'mary', 'pw': 'adas5fs'} - - # INSERT many - if no explicit 'values' parameter is sent, - # the first parameter list in the list determines - # the generated SQL of the insert (i.e. what columns are present) - # executemany() is used at the DBAPI level - {sql}users.insert().execute( - {'user_id':7, 'user_name':'jack', 'password':'asdfasdf'}, - {'user_id':8, 'user_name':'ed', 'password':'asdffcadf'}, - {'user_id':9, 'user_name':'fred', 'password':'asttf'}, - ) - INSERT INTO users (user_id, user_name, password) - VALUES (:user_id, :user_name, :password) - [{'user_name': 'jack', 'password': 'asdfasdf', 'user_id': 7}, - {'user_name': 'ed', 'password': 'asdffcadf', 'user_id': 8}, - {'user_name': 'fred', 'password': 'asttf', 'user_id': 9}] - -### Updates {@name=update} - -Updates work a lot like INSERTS, except there is an additional WHERE clause that can be specified. 
- - {python title="Using update()"} - # change 'jack' to 'ed' - {sql}users.update(users.c.user_name=='jack').execute(user_name='ed') - UPDATE users SET user_name=:user_name WHERE users.user_name = :users_user_name - {'users_user_name': 'jack', 'user_name': 'ed'} - - # use bind parameters - u = users.update(users.c.user_name==bindparam('name'), - values={'user_name':bindparam('newname')}) - {sql}u.execute(name='jack', newname='ed') - UPDATE users SET user_name=:newname WHERE users.user_name = :name - {'newname': 'ed', 'name': 'jack'} - - # update a column to another column - {sql}users.update(values={users.c.password:users.c.user_name}).execute() - UPDATE users SET password=users.user_name - {} - - # expressions OK too - {sql}users.update(values={users.c.user_id:users.c.user_id + 17}).execute() - UPDATE users SET user_id=users.user_id + :users_user_id - {'users_user_id':17} - - # multi-update - {sql}users.update(users.c.user_id==bindparam('id')).execute( - {'id':7, 'user_name':'jack', 'password':'fh5jks'}, - {'id':8, 'user_name':'ed', 'password':'fsr234ks'}, - {'id':9, 'user_name':'mary', 'password':'7h5jse'}, - ) - UPDATE users SET user_name=:user_name, password=:password WHERE users.user_id = :id - [{'password': 'fh5jks', 'user_name': 'jack', 'id': 7}, - {'password': 'fsr234ks', 'user_name': 'ed', 'id': 8}, - {'password': '7h5jse', 'user_name': 'mary', 'id': 9}] - -#### Correlated Updates {@name=correlated} - -A correlated update lets you update a table using selection from another table, or the same table: - - {python}s = select([addresses.c.city], addresses.c.user_id==users.c.user_id) - {sql}users.update( - and_(users.c.user_id>10, users.c.user_id<20), - values={users.c.user_name:s} - ).execute() - UPDATE users SET user_name=(SELECT addresses.city - FROM addresses - WHERE addresses.user_id = users.user_id) - WHERE users.user_id > :users_user_id AND users.user_id < :users_user_id_1 - {'users_user_id_1': 20, 'users_user_id': 10} - -### Deletes {@name=delete} - -A delete is formulated like an update, except theres no values: - - {python}users.delete(users.c.user_id==7).execute() - users.delete(users.c.user_name.like(bindparam('name'))).execute( - {'name':'%Jack%'}, - {'name':'%Ed%'}, - {'name':'%Jane%'}, - ) - users.delete(exists()) - diff --git a/doc/build/content/sqlexpression.txt b/doc/build/content/sqlexpression.txt new file mode 100644 index 0000000000..ec7e92c245 --- /dev/null +++ b/doc/build/content/sqlexpression.txt @@ -0,0 +1,933 @@ +SQL Expression Language Tutorial {@name=sql} +=============================================== + +This tutorial will cover SQLAlchemy SQL Expressions, which are Python constructs that represent SQL statements. The tutorial is in doctest format, meaning each `>>>` line represents something you can type at a Python command prompt, and the following text represents the expected return value. The tutorial has no prerequisites. + +## Version Check + +A quick check to verify that we are on at least **version 0.4** of SQLAlchemy: + + {python} + >>> import sqlalchemy + >>> sqlalchemy.__version__ # doctest:+SKIP + 0.4.0 + +## Connecting + +For this tutorial we will use an in-memory-only SQLite database. This is an easy way to test things without needing to have an actual database defined anywhere. 
To connect we use `create_engine()`: + + {python} + >>> from sqlalchemy import create_engine + >>> engine = create_engine('sqlite:///:memory:', echo=True) + +The `echo` flag is a shortcut to setting up SQLAlchemy logging, which is accomplished via Python's standard `logging` module. With it enabled, we'll see all the generated SQL produced. If you are working through this tutorial and want less output generated, set it to `False`. This tutorial will format the SQL behind a popup window so it doesn't get in our way; just click the "SQL" links to see whats being generated. + +## Define and Create Tables {@name=tables} + +The SQL Expression Language constructs its expressions in most cases against table columns. In SQLAlchemy, a column is most often represented by an object called `Column`, and in all cases a `Column` is associated with a `Table`. A collection of `Table` objects and their associated child objects is referred to as **database metadata**. In this tutorial we will explicitly lay out several `Table` objects, but note that SA can also "import" whole sets of `Table` objects automatically from an existing database (this process is called **table reflection**). + +We define our tables all within a catalog called `MetaData`, using the `Table` construct, which resembles regular SQL CREATE TABLE statements. We'll make two tables, one of which represents "users" in an application, and another which represents zero or more "email addreses" for each row in the "users" table: + + {python} + >>> from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey + >>> metadata = MetaData() + >>> users = Table('users', metadata, + ... Column('id', Integer, primary_key=True), + ... Column('name', String(40)), + ... Column('fullname', String(100)), + ... ) + + >>> addresses = Table('addresses', metadata, + ... Column('id', Integer, primary_key=True), + ... Column('user_id', None, ForeignKey('users.id')), + ... Column('email_address', String(50), nullable=False) + ... ) + +All about how to define `Table` objects, as well as how to create them from an existing database automatically, is described in [metadata](rel:metadata). + +Next, to tell the `MetaData` we'd actually like to create our selection of tables for real inside the SQLite database, we use `create_all()`, passing it the `engine` instance which points to our database. This will check for the presence of each table first before creating, so it's safe to call multiple times: + + {python} + {sql}>>> metadata.create_all(engine) #doctest: +NORMALIZE_WHITESPACE + PRAGMA table_info("users") + {} + PRAGMA table_info("addresses") + {} + CREATE TABLE users ( + id INTEGER NOT NULL, + name VARCHAR(40), + fullname VARCHAR(100), + PRIMARY KEY (id) + ) + {} + COMMIT + CREATE TABLE addresses ( + id INTEGER NOT NULL, + user_id INTEGER, + email_address VARCHAR(50) NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY(user_id) REFERENCES users (id) + ) + {} + COMMIT + +## Insert Expressions + +The first SQL expression we'll create is the `Insert` construct, which represents an INSERT statement. This is typically created relative to its target table: + + {python} + >>> ins = users.insert() + +To see a sample of the SQL this construct produces, use the `str()` function: + + {python} + >>> str(ins) + 'INSERT INTO users (id, name, fullname) VALUES (:id, :name, :fullname)' + +Notice above that the INSERT statement names every column in the `users` table. 
This can be limited by using the `values` keyword, which establishes the VALUES clause of the INSERT explicitly: + + {python} + >>> ins = users.insert(values={'name':'jack', 'fullname':'Jack Jones'}) + >>> str(ins) + 'INSERT INTO users (name, fullname) VALUES (:name, :fullname)' + +Above, while the `values` keyword limited the VALUES clause to just two columns, the actual data we placed in `values` didn't get rendered into the string; instead we got named bind parameters. As it turns out, our data *is* stored within our `Insert` construct, but it typically only comes out when the statement is actually executed; since the data consists of literal values, SQLAlchemy automatically generates bind parameters for them. We can peek at this data for now by looking at the compiled form of the statement: + + {python} + >>> ins.compile().params #doctest: +NORMALIZE_WHITESPACE + {'fullname': 'Jack Jones', 'name': 'jack'} + +## Executing {@name=executing} + +The interesting part of an `Insert` is executing it. In this tutorial, we will generally focus on the most explicit method of executing a SQL construct, and later touch upon some "shortcut" ways to do it. The `engine` object we created is a repository for database connections capable of issuing SQL to the database. To acquire a connection, we use the `connect()` method: + + {python} + >>> conn = engine.connect() + >>> conn #doctest: +ELLIPSIS + + +The `Connection` object represents an actively checked out DBAPI connection resource. Lets feed it our `Insert` object and see what happens: + + {python} + >>> result = conn.execute(ins) + {opensql}INSERT INTO users (name, fullname) VALUES (?, ?) + ['jack', 'Jack Jones'] + COMMIT + +So the INSERT statement was now issued to the database. Although we got positional "qmark" bind parameters instead of "named" bind parameters in the output. How come ? Because when executed, the `Connection` used the SQLite **dialect** to help generate the statement; when we use the `str()` function, the statement isn't aware of this dialect, and falls back onto a default which uses named parameters. We can view this manually as follows: + + {python} + >>> ins.bind = engine + >>> str(ins) + 'INSERT INTO users (name, fullname) VALUES (?, ?)' + +What about the `result` variable we got when we called `execute()` ? As the SQLAlchemy `Connection` object references a DBAPI connection, the result, known as a `ResultProxy` object, is analogous to the DBAPI cursor object. In the case of an INSERT, we can get important information from it, such as the primary key values which were generated from our statement: + + {python} + >>> result.last_inserted_ids() + [1] + +The value of `1` was automatically generated by SQLite, but only because we did not specify the `id` column in our `Insert` statement; otherwise, our explicit value would have been used. In either case, SQLAlchemy always knows how to get at a newly generated primary key value, even though the method of generating them is different across different databases; each databases' `Dialect` knows the specific steps needed to determine the correct value (or values; note that `last_inserted_ids()` returns a list so that it supports composite primary keys). + +## Executing Multiple Statements {@name=execmany} + +Our insert example above was intentionally a little drawn out to show some various behaviors of expression language constructs. 
In the usual case, an `Insert` statement is compiled against the parameters sent to the `execute()` method on `Connection`, so that there's no need to use the `values` keyword with `Insert`. Let's create a generic `Insert` statement again and use it in the "normal" way: + + {python} + >>> ins = users.insert() + >>> conn.execute(ins, id=2, name='wendy', fullname='Wendy Williams') # doctest: +ELLIPSIS + {opensql}INSERT INTO users (id, name, fullname) VALUES (?, ?, ?) + [2, 'wendy', 'Wendy Williams'] + COMMIT + {stop} + +Above, because we specified all three columns in the `execute()` method, the compiled `Insert` included all three columns. The `Insert` statement is compiled at execution time based on the parameters we specified; if we specified fewer parameters, the `Insert` would have fewer entries in its VALUES clause. + +To issue many inserts using DBAPI's `executemany()` method, we can send in a list of dictionaries each containing a distinct set of parameters to be inserted, as we do here to add some email addresses: + + {python} + >>> conn.execute(addresses.insert(), [ # doctest: +ELLIPSIS + ... {'user_id': 1, 'email_address' : 'jack@yahoo.com'}, + ... {'user_id': 1, 'email_address' : 'jack@msn.com'}, + ... {'user_id': 2, 'email_address' : 'www@www.org'}, + ... {'user_id': 2, 'email_address' : 'wendy@aol.com'}, + ... ]) + {opensql}INSERT INTO addresses (user_id, email_address) VALUES (?, ?) + [[1, 'jack@yahoo.com'], [1, 'jack@msn.com'], [2, 'www@www.org'], [2, 'wendy@aol.com']] + COMMIT + {stop} + +Above, we again relied upon SQLite's automatic generation of primary key identifiers for each `addresses` row. + +When executing multiple sets of parameters, each dictionary must have the **same** set of keys; i.e. you can't have fewer keys in some dictionaries than others. This is because the `Insert` statement is compiled against the **first** dictionary in the list, and it's assumed that all subsequent argument dictionaries are compatible with that statement. + +## Connectionless / Implicit Execution {@name=connectionless} + +We're executing our `Insert` using a `Connection`. There are two options that allow you to not have to deal with the connection part. You can execute in the **connectionless** style, using the engine, which opens and closes a connection for you: + + {python} + {sql}>>> result = engine.execute(users.insert(), name='fred', fullname="Fred Flintstone") + INSERT INTO users (name, fullname) VALUES (?, ?) + ['fred', 'Fred Flintstone'] + COMMIT + +and you can save even more steps than that, if you connect the `Engine` to the `MetaData` object we created earlier. When this is done, all SQL expressions which involve tables within the `MetaData` object will be automatically **bound** to the `Engine`. In this case, we call it **implicit execution**: + + {python} + >>> metadata.bind = engine + {sql}>>> result = users.insert().execute(name="mary", fullname="Mary Contrary") + INSERT INTO users (name, fullname) VALUES (?, ?) + ['mary', 'Mary Contrary'] + COMMIT + +When the `MetaData` is bound, statements will also compile against the engine's dialect. Since a lot of the examples here assume the default dialect, we'll detach the engine from the metadata which we just attached: + + {python} + >>> metadata.bind = None + +Detailed examples of connectionless and implicit execution are available in the "Engines" chapter: [dbengine_implicit](rel:dbengine_implicit). + +## Selecting {@name=selecting} + +We began with inserts just so that our test database had some data in it.
The more interesting part of the data is selecting it ! We'll cover UPDATE and DELETE statements later. The primary construct used to generate SELECT statements is the `select()` function: + + {python} + >>> from sqlalchemy.sql import select + >>> s = select([users]) + {opensql}>>> result = conn.execute(s) + SELECT users.id, users.name, users.fullname + FROM users + [] + +Above, we issued a basic `select()` call, placing the `users` table within the COLUMNS clause of the select, and then executing. SQLAlchemy expanded the `users` table into the set of each of its columns, and also generated a FROM clause for us. The result returned is again a `ResultProxy` object, which acts much like a DBAPI cursor, including methods such as `fetchone()` and `fetchall()`. The easiest way to get rows from it is to just iterate: + + {python} + >>> for row in result: + ... print row + (1, u'jack', u'Jack Jones') + (2, u'wendy', u'Wendy Williams') + (3, u'fred', u'Fred Flintstone') + (4, u'mary', u'Mary Contrary') + +Above, we see that printing each row produces a simple tuple-like result. We have more options at accessing the data in each row. One very common way is through dictionary access, using the string names of columns: + + {python} + {sql}>>> result = conn.execute(s) + SELECT users.id, users.name, users.fullname + FROM users + [] + + >>> row = result.fetchone() + >>> print "name:", row['name'], "; fullname:", row['fullname'] + name: jack ; fullname: Jack Jones + +Integer indexes work as well: + + {python} + >>> row = result.fetchone() + >>> print "name:", row[1], "; fullname:", row[2] + name: wendy ; fullname: Wendy Williams + +But another way, whose usefulness will become apparent later on, is to use the `Column` objects directly as keys: + + {python} + {sql}>>> for row in conn.execute(s): + ... print "name:", row[users.c.name], "; fullname:", row[users.c.fullname] + SELECT users.id, users.name, users.fullname + FROM users + [] + {stop}name: jack ; fullname: Jack Jones + name: wendy ; fullname: Wendy Williams + name: fred ; fullname: Fred Flintstone + name: mary ; fullname: Mary Contrary + +Result sets which have pending rows remaining should be explicitly closed before discarding. While the resources referenced by the `ResultProxy` will be closed when the object is garbage collected, it's better to make it explicit as some database APIs are very picky about such things: + + {python} + >>> result.close() + +If we'd like to more carefully control the columns which are placed in the COLUMNS clause of the select, we reference individual `Column` objects from our `Table`. These are available as named attributes off the `c` attribute of the `Table` object: + + {python} + >>> s = select([users.c.name, users.c.fullname]) + {sql}>>> result = conn.execute(s) + SELECT users.name, users.fullname + FROM users + [] + {stop}>>> for row in result: #doctest: +NORMALIZE_WHITESPACE + ... print row + (u'jack', u'Jack Jones') + (u'wendy', u'Wendy Williams') + (u'fred', u'Fred Flintstone') + (u'mary', u'Mary Contrary') + +Lets observe something interesting about the FROM clause. Whereas the generated statement contains two distinct sections, a "SELECT columns" part and a "FROM table" part, our `select()` construct only has a list containing columns. How does this work ? Let's try putting *two* tables into our `select()` statement: + + {python} + {sql}>>> for row in conn.execute(select([users, addresses])): + ... 
print row + SELECT users.id, users.name, users.fullname, addresses.id, addresses.user_id, addresses.email_address + FROM users, addresses + [] + {stop}(1, u'jack', u'Jack Jones', 1, 1, u'jack@yahoo.com') + (1, u'jack', u'Jack Jones', 2, 1, u'jack@msn.com') + (1, u'jack', u'Jack Jones', 3, 2, u'www@www.org') + (1, u'jack', u'Jack Jones', 4, 2, u'wendy@aol.com') + (2, u'wendy', u'Wendy Williams', 1, 1, u'jack@yahoo.com') + (2, u'wendy', u'Wendy Williams', 2, 1, u'jack@msn.com') + (2, u'wendy', u'Wendy Williams', 3, 2, u'www@www.org') + (2, u'wendy', u'Wendy Williams', 4, 2, u'wendy@aol.com') + (3, u'fred', u'Fred Flintstone', 1, 1, u'jack@yahoo.com') + (3, u'fred', u'Fred Flintstone', 2, 1, u'jack@msn.com') + (3, u'fred', u'Fred Flintstone', 3, 2, u'www@www.org') + (3, u'fred', u'Fred Flintstone', 4, 2, u'wendy@aol.com') + (4, u'mary', u'Mary Contrary', 1, 1, u'jack@yahoo.com') + (4, u'mary', u'Mary Contrary', 2, 1, u'jack@msn.com') + (4, u'mary', u'Mary Contrary', 3, 2, u'www@www.org') + (4, u'mary', u'Mary Contrary', 4, 2, u'wendy@aol.com') + +It placed **both** tables into the FROM clause. But also, it made a real mess. Those who are familiar with SQL joins know that this is a **Cartesian product**; each row from the `users` table is produced against each row from the `addresses` table. So to put some sanity into this statement, we need a WHERE clause. Which brings us to the second argument of `select()`: + + {python} + >>> s = select([users, addresses], users.c.id==addresses.c.user_id) + {sql}>>> for row in conn.execute(s): + ... print row + SELECT users.id, users.name, users.fullname, addresses.id, addresses.user_id, addresses.email_address + FROM users, addresses + WHERE users.id = addresses.user_id + [] + {stop}(1, u'jack', u'Jack Jones', 1, 1, u'jack@yahoo.com') + (1, u'jack', u'Jack Jones', 2, 1, u'jack@msn.com') + (2, u'wendy', u'Wendy Williams', 3, 2, u'www@www.org') + (2, u'wendy', u'Wendy Williams', 4, 2, u'wendy@aol.com') + +So that looks a lot better, we added an expression to our `select()` which had the effect of adding `WHERE users.id = addresses.user_id` to our statement, and our results were managed down so that the join of `users` and `addresses` rows made sense. But let's look at that expression? It's using just a Python equality operator between two different `Column` objects. It should be clear that something is up. Saying `1==1` produces `True`, and `1==2` produces `False`, not a WHERE clause. So lets see exactly what that expression is doing: + + {python} + >>> users.c.id==addresses.c.user_id #doctest: +ELLIPSIS + + +Wow, surprise ! This is neither a `True` nor a `False`. Well what is it ? + + {python} + >>> str(users.c.id==addresses.c.user_id) + 'users.id = addresses.user_id' + +As you can see, the `==` operator is producing an object that is very much like the `Insert` and `select()` objects we've made so far, thanks to Python's `__eq__()` builtin; you call `str()` on it and it produces SQL. By now, one can that everything we are working with is ultimately the same type of object. SQLAlchemy terms the base class of all of these expressions as `sqlalchemy.sql.ClauseElement`. + +## Operators {@name=operators} + +Since we've stumbled upon SQLAlchemy's operator paradigm, let's go through some of its capabilities. 
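One capability worth noting up front: the expressions produced by these operators are ordinary Python objects, so they can be assigned to variables, inspected, and reused; a minimal sketch, assuming the `users` table defined earlier (the auto-generated bind parameter name may differ slightly):

    {python}
    >>> # build the expression once and hold onto it; nothing is executed yet
    >>> name_is_jack = users.c.name == 'jack'
    >>> print name_is_jack
    users.name = :users_name_1
    >>> name_is_jack.compile().params
    {'users_name_1': 'jack'}

We'll lean on this kind of piecemeal construction more heavily later on, when statements are built up generatively.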
We've seen how to equate two columns to each other: + + {python} + >>> print users.c.id==addresses.c.user_id + users.id = addresses.user_id + +If we use a literal value (a literal meaning, not a SQLAlchemy clause object), we get a bind parameter: + + {python} + >>> print users.c.id==7 + users.id = :users_id_1 + +The `7` literal is embedded in `ClauseElement`; we can use the same trick we did with the `Insert` object to see it: + + {python} + >>> (users.c.id==7).compile().params + {'users_id_1': 7} + +Most Python operators, as it turns out, produce a SQL expression here, like equals, not equals, etc.: + + {python} + >>> print users.c.id != 7 + users.id != :users_id_1 + + >>> # None converts to IS NULL + >>> print users.c.name == None + users.name IS NULL + + >>> # reverse works too + >>> print 'fred' > users.c.name + users.name < :users_name_1 + +If we add two integer columns together, we get an addition expression: + + {python} + >>> print users.c.id + addresses.c.id + users.id + addresses.id + +Interestingly, the type of the `Column` is important ! If we use `+` with two string based columns (recall we put types like `Integer` and `String` on our `Column` objects at the beginning), we get something different: + + {python} + >>> print users.c.name + users.c.fullname + users.name || users.fullname + +Where `||` is the string concatenation operator used on most databases. But not all of them. MySQL users, fear not: + + {python} + >>> print (users.c.name + users.c.fullname).compile(bind=create_engine('mysql://')) + concat(users.name, users.fullname) + +The above illustrates the SQL that's generated for an `Engine` that's connected to a MySQL database; the `||` operator now compiles as MySQL's `concat()` function. + +If you have come across an operator which really isn't available, you can always use the `op()` method; this generates whatever operator you need: + + {python} + >>> print users.c.name.op('tiddlywinks')('foo') + users.name tiddlywinks :users_name_1 + +## Conjunctions {@name=conjunctions} + +We'd like to show off some of our operators inside of `select()` constructs. But we need to lump them together a little more, so let's first introduce some conjunctions. Conjunctions are those little words like AND and OR that put things together. We'll also hit upon NOT. AND, OR and NOT can work from the corresponding functions SQLAlchemy provides (notice we also throw in a LIKE): + + {python} + >>> from sqlalchemy.sql import and_, or_, not_ + >>> print and_(users.c.name.like('j%'), users.c.id==addresses.c.user_id, #doctest: +NORMALIZE_WHITESPACE + ... or_(addresses.c.email_address=='wendy@aol.com', addresses.c.email_address=='jack@yahoo.com'), + ... not_(users.c.id>5)) + users.name LIKE :users_name_1 AND users.id = addresses.user_id AND + (addresses.email_address = :addresses_email_address_1 OR addresses.email_address = :addresses_email_address_2) + AND users.id <= :users_id_1 + +And you can also use the re-jiggered bitwise AND, OR and NOT operators, although because of Python operator precedence you have to watch your parenthesis: + + {python} + >>> print users.c.name.like('j%') & (users.c.id==addresses.c.user_id) & \ + ... ((addresses.c.email_address=='wendy@aol.com') | (addresses.c.email_address=='jack@yahoo.com')) \ + ... 
& ~(users.c.id>5) # doctest: +NORMALIZE_WHITESPACE + users.name LIKE :users_name_1 AND users.id = addresses.user_id AND + (addresses.email_address = :addresses_email_address_1 OR addresses.email_address = :addresses_email_address_2) + AND users.id <= :users_id_1 + +So with all of this vocabulary, let's select all users who have an email address at AOL or MSN, whose name starts with a letter between "m" and "z", and we'll also generate a column containing their full name combined with their email address. We will add two new constructs to this statement, `between()` and `label()`. `between()` produces a BETWEEN clause, and `label()` is used in a column expression to produce labels using the `AS` keyword; it's recommended when selecting from expressions that otherwise would not have a name: + + {python} + >>> s = select([(users.c.fullname + ", " + addresses.c.email_address).label('title')], + ... and_( + ... users.c.id==addresses.c.user_id, + ... users.c.name.between('m', 'z'), + ... or_( + ... addresses.c.email_address.like('%@aol.com'), + ... addresses.c.email_address.like('%@msn.com') + ... ) + ... ) + ... ) + >>> print conn.execute(s).fetchall() #doctest: +NORMALIZE_WHITESPACE + SELECT users.fullname || ? || addresses.email_address AS title + FROM users, addresses + WHERE users.id = addresses.user_id AND users.name BETWEEN ? AND ? AND + (addresses.email_address LIKE ? OR addresses.email_address LIKE ?) + [', ', 'm', 'z', '%@aol.com', '%@msn.com'] + [(u'Wendy Williams, wendy@aol.com',)] + +Once again, SQLAlchemy figured out the FROM clause for our statement. In fact it will determine the FROM clause based on all of its other bits; the columns clause, the whereclause, and also some other elements which we haven't covered yet, which include ORDER BY, GROUP BY, and HAVING. + +## Using Text {@name=text} + +Our last example really became a handful to type. Going from what one understands to be a textual SQL expression into a Python construct which groups components together in a programmatic style can be hard. That's why SQLAlchemy lets you just use strings too. The `text()` construct represents any textual statement. To use bind parameters with `text()`, always use the named colon format. Such as below, we create a `text()` and execute it, feeding in the bind parameters to the `execute()` method: + + {python} + >>> from sqlalchemy.sql import text + >>> s = text("""SELECT users.fullname || ', ' || addresses.email_address AS title + ... FROM users, addresses + ... WHERE users.id = addresses.user_id AND users.name BETWEEN :x AND :y AND + ... (addresses.email_address LIKE :e1 OR addresses.email_address LIKE :e2) + ... """) + {sql}>>> print conn.execute(s, x='m', y='z', e1='%@aol.com', e2='%@msn.com').fetchall() # doctest:+NORMALIZE_WHITESPACE + SELECT users.fullname || ', ' || addresses.email_address AS title + FROM users, addresses + WHERE users.id = addresses.user_id AND users.name BETWEEN ? AND ? AND + (addresses.email_address LIKE ? OR addresses.email_address LIKE ?) + ['m', 'z', '%@aol.com', '%@msn.com'] + {stop}[(u'Wendy Williams, wendy@aol.com',)] + +To gain a "hybrid" approach, any of SA's SQL constructs can have text freely intermingled wherever you like - the `text()` construct can be placed within any other `ClauseElement` construct, and when used in a non-operator context, a direct string may be placed which converts to `text()` automatically. 
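For instance, a plain string passed as the whereclause is coerced into a `text()` construct for us and rendered into the statement as-is; a quick sketch, assuming the `users` table from earlier:

    {python}
    >>> # the string is wrapped in text() behind the scenes and passed through verbatim
    >>> print select([users.c.name], "users.id > 2")
    SELECT users.name
    FROM users
    WHERE users.id > 2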
Below we combine the usage of `text()` and strings with our constructed `select()` object, by using the `select()` object to structure the statement, and the `text()`/strings to provide all the content within the structure. For this example, SQLAlchemy is not given any `Column` or `Table` objects in any of its expressions, so it cannot generate a FROM clause. So we also give it the `from_obj` keyword argument, which is a list of `ClauseElements` (or strings) to be placed within the FROM clause: + + {python} + >>> s = select([text("users.fullname || ', ' || addresses.email_address AS title")], + ... and_( + ... "users.id = addresses.user_id", + ... "users.name BETWEEN 'm' AND 'z'", + ... "(addresses.email_address LIKE :x OR addresses.email_address LIKE :y)" + ... ), + ... from_obj=['users', 'addresses'] + ... ) + {sql}>>> print conn.execute(s, x='%@aol.com', y='%@msn.com').fetchall() #doctest: +NORMALIZE_WHITESPACE + SELECT users.fullname || ', ' || addresses.email_address AS title + FROM users, addresses + WHERE users.id = addresses.user_id AND users.name BETWEEN 'm' AND 'z' AND (addresses.email_address LIKE ? OR addresses.email_address LIKE ?) + ['%@aol.com', '%@msn.com'] + {stop}[(u'Wendy Williams, wendy@aol.com',)] + +Going from constructed SQL to text, we lose some capabilities. We lose the capability for SQLAlchemy to compile our expression to a specific target database; above, our expression won't work with MySQL since it has no `||` construct. It also becomes more tedious for SQLAlchemy to be made aware of the datatypes in use; for example, if our bind parameters required UTF-8 encoding before going in, or conversion from a Python `datetime` into a string (as is required with SQLite), we would have to add extra information to our `text()` construct. Similar issues arise on the result set side, where SQLAlchemy also performs type-specific data conversion in some cases; still more information can be added to `text()` to work around this. But what we really lose from our statement is the ability to manipulate it, transform it, and analyze it. These features are critical when using the ORM, which makes heavy usage of relational transformations. To show off what we mean, we'll first introduce the ALIAS construct and the JOIN construct, just so we have some juicier bits to play with. + +## Using Aliases {@name=aliases} + +The alias corresponds to a "renamed" version of a table or arbitrary relation, which occurs anytime you say "SELECT .. FROM sometable AS someothername". The `AS` creates a new name for the table. Aliases are super important in SQL as they allow you to reference the same table more than once. Scenarios where you need to do this include when you self-join a table to itself, or more commonly when you need to join from a parent table to a child table multiple times. For example, we know that our user `jack` has two email addresses. How can we locate jack based on the combination of those two addresses? We need to join twice to it. Let's construct two distinct aliases for the `addresses` table and join: + + {python} + >>> a1 = addresses.alias('a1') + >>> a2 = addresses.alias('a2') + >>> s = select([users], and_( + ... users.c.id==a1.c.user_id, + ... users.c.id==a2.c.user_id, + ... a1.c.email_address=='jack@msn.com', + ... a2.c.email_address=='jack@yahoo.com' + ... 
)) + {sql}>>> print conn.execute(s).fetchall() + SELECT users.id, users.name, users.fullname + FROM users, addresses AS a1, addresses AS a2 + WHERE users.id = a1.user_id AND users.id = a2.user_id AND a1.email_address = ? AND a2.email_address = ? + ['jack@msn.com', 'jack@yahoo.com'] + {stop}[(1, u'jack', u'Jack Jones')] + +Easy enough. One thing that we're going for with the SQL Expression Language is the melding of programmatic behavior with SQL generation. Coming up with names like `a1` and `a2` is messy; we really didn't need to use those names anywhere, it's just the database that needed them. Plus, we might write some code that uses alias objects that came from several different places, and it's difficult to ensure that they all have unique names. So instead, we just let SQLAlchemy make the names for us, using "anonymous" aliases: + + {python} + >>> a1 = addresses.alias() + >>> a2 = addresses.alias() + >>> s = select([users], and_( + ... users.c.id==a1.c.user_id, + ... users.c.id==a2.c.user_id, + ... a1.c.email_address=='jack@msn.com', + ... a2.c.email_address=='jack@yahoo.com' + ... )) + {sql}>>> print conn.execute(s).fetchall() + SELECT users.id, users.name, users.fullname + FROM users, addresses AS addresses_1, addresses AS addresses_2 + WHERE users.id = addresses_1.user_id AND users.id = addresses_2.user_id AND addresses_1.email_address = ? AND addresses_2.email_address = ? + ['jack@msn.com', 'jack@yahoo.com'] + {stop}[(1, u'jack', u'Jack Jones')] + +One super-huge advantage of anonymous aliases is that not only did we not have to guess up a random name, but we can also be guaranteed that the above SQL string is **deterministically** generated to be the same every time. This is important for databases such as Oracle which cache compiled "query plans" for their statements, and need to see the same SQL string in order to make use of it. + +Aliases can of course be used for anything which you can SELECT from, including SELECT statements themselves. We can self-join the `users` table back to the `select()` we've created by making an alias of the entire statement. The `correlate(None)` directive is to avoid SQLAlchemy's attempt to "correlate" the inner `users` table with the outer one: + + {python} + >>> a1 = s.correlate(None).alias() + >>> s = select([users.c.name], users.c.id==a1.c.id) + {sql}>>> print conn.execute(s).fetchall() + SELECT users.name + FROM users, (SELECT users.id AS id, users.name AS name, users.fullname AS fullname + FROM users, addresses AS addresses_1, addresses AS addresses_2 + WHERE users.id = addresses_1.user_id AND users.id = addresses_2.user_id AND addresses_1.email_address = ? AND addresses_2.email_address = ?) AS anon_1 + WHERE users.id = anon_1.id + ['jack@msn.com', 'jack@yahoo.com'] + {stop}[(u'jack',)] + +## Using Joins {@name=joins} + +We're halfway along to being able to construct any SELECT expression. The next cornerstone of the SELECT is the JOIN expression. We've already been doing joins in our examples, by just placing two tables in either the columns clause or the where clause of the `select()` construct. But if we want to make a real "JOIN" or "OUTERJOIN" construct, we use the `join()` and `outerjoin()` methods, most commonly accessed from the left table in the join: + + {python} + >>> print users.join(addresses) + users JOIN addresses ON users.id = addresses.user_id + +The alert reader will see more surprises; SQLAlchemy figured out how to JOIN the two tables ! 
The ON condition of the join, as it's called, was automatically generated based on the `ForeignKey` object which we placed on the `addresses` table way at the beginning of this tutorial. Already the `join()` construct is looking like a much better way to join tables. + +Of course you can join on whatever expression you want, such as if we want to join on all users who use the same name in their email address as their username: + + {python} + >>> print users.join(addresses, addresses.c.email_address.like(users.c.name + '%')) + users JOIN addresses ON addresses.email_address LIKE users.name || :users_name_1 + +When we create a `select()` construct, SQLAlchemy looks around at the tables we've mentioned and then places them in the FROM clause of the statement. When we use JOINs however, we know what FROM clause we want, so here we make usage of the `from_obj` keyword argument: + + {python} + >>> s = select([users.c.fullname], from_obj=[ + ... users.join(addresses, addresses.c.email_address.like(users.c.name + '%')) + ... ]) + {sql}>>> print conn.execute(s).fetchall() + SELECT users.fullname + FROM users JOIN addresses ON addresses.email_address LIKE users.name || ? + ['%'] + {stop}[(u'Jack Jones',), (u'Jack Jones',), (u'Wendy Williams',)] + +The `outerjoin()` function just creates `LEFT OUTER JOIN` constructs. It's used just like `join()`: + + {python} + >>> s = select([users.c.fullname], from_obj=[users.outerjoin(addresses)]) + >>> print s + SELECT users.fullname + FROM users LEFT OUTER JOIN addresses ON users.id = addresses.user_id + +That's the output `outerjoin()` produces, unless, of course, you're stuck in a gig using Oracle prior to version 9, and you've set up your engine (which would be using `OracleDialect`) to use Oracle-specific SQL: + + {python} + >>> from sqlalchemy.databases.oracle import OracleDialect + >>> print s.compile(dialect=OracleDialect(use_ansi=False)) + SELECT users.fullname + FROM users, addresses + WHERE users.id = addresses.user_id(+) + +If you don't know what that SQL means, don't worry ! The secret tribe of Oracle DBAs don't want their black magic being found out ;). + +## Intro to Generative Selects and Transformations {@name=transform} + +We've now gained the ability to construct very sophisticated statements. We can use all kinds of operators, table constructs, text, joins, and aliases. The point of all of this, as mentioned earlier, is not that it's an "easier" or "better" way to write SQL than just writing a SQL statement yourself; the point is that it's better for writing *programmatically generated* SQL which can be morphed and adapted as needed in automated scenarios. + +To support this, the `select()` construct we've been working with supports piecemeal construction, in addition to the "all at once" method we've been doing. Suppose you're writing a search function, which receives criterion and then must construct a select from it. To accomplish this, upon each criterion encountered, you apply "generative" criterion to an existing `select()` construct with new elements, one at a time. We start with a basic `select()` constructed with the shortcut method available on the `users` table: + + {python} + >>> query = users.select() + >>> print query + SELECT users.id, users.name, users.fullname + FROM users + +We encounter search criterion of "name='jack'". So we apply WHERE criterion stating such: + + {python} + >>> query = query.where(users.c.name=='jack') + +Next, we encounter that they'd like the results in descending order by full name. 
We apply ORDER BY, using an extra modifier `desc`: + + {python} + >>> query = query.order_by(users.c.fullname.desc()) + +We also come across that they'd like only users who have an address at MSN. A quick way to tack this on is by using an EXISTS clause, which we correlate to the `users` table in the enclosing SELECT: + + {python} + >>> from sqlalchemy.sql import exists + >>> query = query.where( + ... exists([addresses.c.id], + ... and_(addresses.c.user_id==users.c.id, addresses.c.email_address.like('%@msn.com')) + ... ).correlate(users)) + +And finally, the application also wants to see the listing of email addresses at once; so to save queries, we outerjoin the `addresses` table (using an outer join so that users with no addresses come back as well; since we're programmatic, we might not have kept track that we used an EXISTS clause against the `addresses` table too...). Additionally, since the `users` and `addresses` table both have a column named `id`, let's isolate their names from each other in the COLUMNS clause by using labels: + + {python} + >>> query = query.column(addresses).select_from(users.outerjoin(addresses)).apply_labels() + +Let's bake for .0001 seconds and see what rises: + + {python} + {opensql}>>> conn.execute(query).fetchall() + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, addresses.email_address AS addresses_email_address + FROM users LEFT OUTER JOIN addresses ON users.id = addresses.user_id + WHERE users.name = ? AND (EXISTS (SELECT addresses.id + FROM addresses + WHERE addresses.user_id = users.id AND addresses.email_address LIKE ?)) ORDER BY users.fullname DESC + ['jack', '%@msn.com'] + {stop}[(1, u'jack', u'Jack Jones', 1, 1, u'jack@yahoo.com'), (1, u'jack', u'Jack Jones', 2, 1, u'jack@msn.com')] + +So we started small, added one little thing at a time, and at the end we have a huge statement..which actually works. Now let's do one more thing; the searching function wants to add another `email_address` criterion on, however it doesn't want to construct an alias of the `addresses` table; suppose many parts of the application are written to deal specifically with the `addresses` table, and to change all those functions to support receiving an arbitrary alias of the address would be cumbersome. We can actually *convert* the `addresses` table within the *existing* statement to be an alias of itself, using `replace_selectable()`: + + {python} + >>> a1 = addresses.alias() + >>> query = query.replace_selectable(addresses, a1) + >>> print query + {opensql}SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, addresses_1.id AS addresses_1_id, addresses_1.user_id AS addresses_1_user_id, addresses_1.email_address AS addresses_1_email_address + FROM users LEFT OUTER JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id + WHERE users.name = :users_name_1 AND (EXISTS (SELECT addresses_1.id + FROM addresses AS addresses_1 + WHERE addresses_1.user_id = users.id AND addresses_1.email_address LIKE :addresses_email_address_1)) ORDER BY users.fullname DESC + +One more thing though, with automatic labeling applied as well as anonymous aliasing, how do we retrieve the columns from the rows for this thing ? The label for the `email_addresses` column is now the generated name `addresses_1_email_address`; and in another statement might be something different ! 
This is where accessing result columns by `Column` object becomes very useful: + + {python} + {sql}>>> for row in conn.execute(query): + ... print "Name:", row[users.c.name], "; Email Address", row[a1.c.email_address] + SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, addresses_1.id AS addresses_1_id, addresses_1.user_id AS addresses_1_user_id, addresses_1.email_address AS addresses_1_email_address + FROM users LEFT OUTER JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id + WHERE users.name = ? AND (EXISTS (SELECT addresses_1.id + FROM addresses AS addresses_1 + WHERE addresses_1.user_id = users.id AND addresses_1.email_address LIKE ?)) ORDER BY users.fullname DESC + ['jack', '%@msn.com'] + {stop}Name: jack ; Email Address jack@yahoo.com + Name: jack ; Email Address jack@msn.com + +The above example, by its end, got significantly more intense than the typical end-user constructed SQL will usually be. However, when writing higher-level tools such as ORMs, such transformations become far more significant. SQLAlchemy's ORM relies very heavily on techniques like this. + +## Everything Else {@name=everythingelse} + +The concepts of creating SQL expressions have been introduced. What's left are more variants of the same themes. So now we'll catalog the rest of the important things we'll need to know. + +### Bind Parameter Objects {@name=bindparams} + +Throughout all these examples, SQLAlchemy is busy creating bind parameters wherever literal expressions occur. You can also specify your own bind parameters with your own names, and use the same statement repeatedly. The database dialect converts to the appropriate named or positional style, as here where it converts to positional for SQLite: + + {python} + >>> from sqlalchemy.sql import bindparam + >>> s = users.select(users.c.name==bindparam('username')) + {sql}>>> conn.execute(s, username='wendy').fetchall() + SELECT users.id, users.name, users.fullname + FROM users + WHERE users.name = ? + ['wendy'] + {stop}[(2, u'wendy', u'Wendy Williams')] + +Another important aspect of bind parameters is that they may be assigned a type. The type of the bind parameter will determine its behavior within expressions and also how the data bound to it is processed before being sent off to the database: + + {python} + >>> s = users.select(users.c.name.like(bindparam('username', type_=String) + text("'%'"))) + {sql}>>> conn.execute(s, username='wendy').fetchall() + SELECT users.id, users.name, users.fullname + FROM users + WHERE users.name LIKE ? || '%' + ['wendy'] + {stop}[(2, u'wendy', u'Wendy Williams')] + + +Bind parameters of the same name can also be used multiple times, where only a single named value is needed in the execute parameters: + + {python} + >>> s = select([users, addresses], + ... users.c.name.like(bindparam('name', type_=String) + text("'%'")) | + ... addresses.c.email_address.like(bindparam('name', type_=String) + text("'@%'")), + ... from_obj=[users.outerjoin(addresses)]) + {sql}>>> conn.execute(s, name='jack').fetchall() + SELECT users.id, users.name, users.fullname, addresses.id, addresses.user_id, addresses.email_address + FROM users LEFT OUTER JOIN addresses ON users.id = addresses.user_id + WHERE users.name LIKE ? || '%' OR addresses.email_address LIKE ?
|| '@%' + ['jack', 'jack'] + {stop}[(1, u'jack', u'Jack Jones', 1, 1, u'jack@yahoo.com'), (1, u'jack', u'Jack Jones', 2, 1, u'jack@msn.com')] + +### Functions {@name=functions} + +SQL functions are created using the `func` construct, which generates functions using attribute access: + + {python} + >>> from sqlalchemy.sql import func + >>> print func.now() + now() + + >>> print func.concat('x', 'y') + concat(:param_1, :param_2) + +Certain functions are marked as "ANSI" functions, which means they don't get parentheses added after them, such as CURRENT_TIMESTAMP: + + {python} + >>> print func.current_timestamp() + CURRENT_TIMESTAMP + +Functions are most typically used in the columns clause of a select statement, and can also be labeled as well as given a type. Labeling a function is recommended so that the result can be targeted in a result row based on a string name, and assigning it a type is required when you need result-set processing to occur, such as for Unicode conversion and date conversions. Below, we use the result function `scalar()` to just read the first column of the first row and then close the result; the label, even though present, is not important in this case: + + {python} + >>> print conn.execute( + ... select([func.max(addresses.c.email_address, type_=String).label('maxemail')]) + ... ).scalar() + {opensql}SELECT max(addresses.email_address) AS maxemail + FROM addresses + [] + {stop}www@www.org + +On databases such as Postgres and Oracle, functions which return whole result sets can be assembled into selectable units, which can be used in statements. For example, given a database function `calculate()` which takes the parameters `x` and `y` and returns three columns which we'd like to name `q`, `z` and `r`, we can construct it using "lexical" column objects as well as bind parameters: + + {python} + >>> from sqlalchemy.sql import column + >>> calculate = select([column('q'), column('z'), column('r')], + ... from_obj=[func.calculate(bindparam('x'), bindparam('y'))]) + + >>> print select([users], users.c.id > calculate.c.z) + SELECT users.id, users.name, users.fullname + FROM users, (SELECT q, z, r + FROM calculate(:x, :y)) + WHERE users.id > z + +If we wanted to use our `calculate` statement twice with different bind parameters, the `unique_params()` method will create copies for us, and mark the bind parameters as "unique" so that conflicting names are isolated. Note we also make two separate aliases of our selectable: + + {python} + >>> s = select([users], users.c.id.between( + ... calculate.alias('c1').unique_params(x=17, y=45).c.z, + ... calculate.alias('c2').unique_params(x=5, y=12).c.z)) + + >>> print s + SELECT users.id, users.name, users.fullname + FROM users, (SELECT q, z, r + FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r + FROM calculate(:x_2, :y_2)) AS c2 + WHERE users.id BETWEEN c1.z AND c2.z + + >>> s.compile().params + {'x_2': 5, 'y_2': 12, 'y_1': 45, 'x_1': 17} + +### Unions and Other Set Operations {@name=unions} + +Unions come in two flavors, UNION and UNION ALL, which are available via module level functions: + + {python} + >>> from sqlalchemy.sql import union + >>> u = union( + ... addresses.select(addresses.c.email_address=='foo@bar.com'), + ... addresses.select(addresses.c.email_address.like('%@yahoo.com')), + ... ).order_by(addresses.c.email_address) + + {sql}>>> print conn.execute(u).fetchall() + SELECT addresses.id, addresses.user_id, addresses.email_address + FROM addresses + WHERE addresses.email_address = ?
UNION SELECT addresses.id, addresses.user_id, addresses.email_address + FROM addresses + WHERE addresses.email_address LIKE ? ORDER BY addresses.email_address + ['foo@bar.com', '%@yahoo.com'] + {stop}[(1, 1, u'jack@yahoo.com')] + +Also available, though not supported on all databases, are `intersect()`, `intersect_all()`, `except_()`, and `except_all()`: + + {python} + >>> from sqlalchemy.sql import except_ + >>> u = except_( + ... addresses.select(addresses.c.email_address.like('%@%.com')), + ... addresses.select(addresses.c.email_address.like('%@msn.com')) + ... ) + + {sql}>>> print conn.execute(u).fetchall() + SELECT addresses.id, addresses.user_id, addresses.email_address + FROM addresses + WHERE addresses.email_address LIKE ? EXCEPT SELECT addresses.id, addresses.user_id, addresses.email_address + FROM addresses + WHERE addresses.email_address LIKE ? + ['%@%.com', '%@msn.com'] + {stop}[(1, 1, u'jack@yahoo.com'), (4, 2, u'wendy@aol.com')] + +### Scalar Selects {@name=scalar} + +To embed a SELECT in a column expression, use `as_scalar()`: + + {python} + {sql}>>> print conn.execute(select([ + ... users.c.name, + ... select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).as_scalar() + ... ])).fetchall() + SELECT users.name, (SELECT count(addresses.id) + FROM addresses + WHERE users.id = addresses.user_id) AS anon_1 + FROM users + [] + {stop}[(u'jack', 2), (u'wendy', 2), (u'fred', 0), (u'mary', 0)] + +Alternatively, applying a `label()` to a select evaluates it as a scalar as well: + + {python} + {sql}>>> print conn.execute(select([ + ... users.c.name, + ... select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).label('address_count') + ... ])).fetchall() + SELECT users.name, (SELECT count(addresses.id) + FROM addresses + WHERE users.id = addresses.user_id) AS address_count + FROM users + [] + {stop}[(u'jack', 2), (u'wendy', 2), (u'fred', 0), (u'mary', 0)] + +### Correlated Subqueries {@name=correlated} + +Notice in the examples on "scalar selects", the FROM clause of each embedded select did not contain the `users` table in its FROM clause. This is because SQLAlchemy automatically attempts to correlate embedded FROM objects to that of an enclosing query. To disable this, or to specify explicit FROM clauses to be correlated, use `correlate()`: + + {python} + >>> s = select([users.c.name], users.c.id==select([users.c.id]).correlate(None)) + >>> print s + SELECT users.name + FROM users + WHERE users.id = (SELECT users.id + FROM users) + + {python} + >>> s = select([users.c.name, addresses.c.email_address], users.c.id== + ... select([users.c.id], users.c.id==addresses.c.user_id).correlate(addresses) + ... ) + >>> print s + SELECT users.name, addresses.email_address + FROM users, addresses + WHERE users.id = (SELECT users.id + FROM users + WHERE users.id = addresses.user_id) + +### Ordering, Grouping, Limiting, Offset...ing... {@name=ordering} + +The `select()` function can take keyword arguments `order_by`, `group_by` (as well as `having`), `limit`, and `offset`. There's also `distinct=True`. These are all also available as generative functions. `order_by()` expressions can use the modifiers `asc()` or `desc()` to indicate ascending or descending. + + {python} + >>> s = select([addresses.c.user_id, func.count(addresses.c.id)]).\ + ... 
group_by(addresses.c.user_id).having(func.count(addresses.c.id)>1) + {opensql}>>> print conn.execute(s).fetchall() + SELECT addresses.user_id, count(addresses.id) + FROM addresses GROUP BY addresses.user_id + HAVING count(addresses.id) > ? + [1] + {stop}[(1, 2), (2, 2)] + + >>> s = select([addresses.c.email_address, addresses.c.id]).distinct().\ + ... order_by(addresses.c.email_address.desc(), addresses.c.id) + {opensql}>>> conn.execute(s).fetchall() + SELECT DISTINCT addresses.email_address, addresses.id + FROM addresses ORDER BY addresses.email_address DESC, addresses.id + [] + {stop}[(u'www@www.org', 3), (u'wendy@aol.com', 4), (u'jack@yahoo.com', 1), (u'jack@msn.com', 2)] + + >>> s = select([addresses]).offset(1).limit(1) + {opensql}>>> print conn.execute(s).fetchall() # doctest: +NORMALIZE_WHITESPACE + SELECT addresses.id, addresses.user_id, addresses.email_address + FROM addresses + LIMIT 1 OFFSET 1 + [] + {stop}[(2, 1, u'jack@msn.com')] + +## Updates {@name=update} + +Finally, we're back to UPDATE. Updates work a lot like INSERTS, except there is an additional WHERE clause that can be specified. + + {python} + >>> # change 'jack' to 'ed' + {sql}>>> conn.execute(users.update(users.c.name=='jack'), name='ed') #doctest: +ELLIPSIS + UPDATE users SET name=? WHERE users.name = ? + ['ed', 'jack'] + COMMIT + {stop} + + >>> # use bind parameters + >>> u = users.update(users.c.name==bindparam('oldname'), values={'name':bindparam('newname')}) + {sql}>>> conn.execute(u, oldname='jack', newname='ed') #doctest: +ELLIPSIS + UPDATE users SET name=? WHERE users.name = ? + ['ed', 'jack'] + COMMIT + {stop} + + >>> # update a column to an expression + {sql}>>> conn.execute(users.update(values={users.c.fullname:"Fullname: " + users.c.name})) #doctest: +ELLIPSIS + UPDATE users SET fullname=(? || users.name) + ['Fullname: '] + COMMIT + {stop} + +### Correlated Updates {@name=correlated} + +A correlated update lets you update a table using selection from another table, or the same table: + + {python} + >>> s = select([addresses.c.email_address], addresses.c.user_id==users.c.id).limit(1) + {sql}>>> conn.execute(users.update(values={users.c.fullname:s})) #doctest: +ELLIPSIS,+NORMALIZE_WHITESPACE + UPDATE users SET fullname=(SELECT addresses.email_address + FROM addresses + WHERE addresses.user_id = users.id + LIMIT 1 OFFSET 0) + [] + COMMIT + {stop} + +## Deletes {@name=delete} + +Finally, a delete. Easy enough: + + {python} + {sql}>>> conn.execute(addresses.delete()) #doctest: +ELLIPSIS + DELETE FROM addresses + [] + COMMIT + {stop} + + {sql}>>> conn.execute(users.delete(users.c.name > 'm')) #doctest: +ELLIPSIS + DELETE FROM users WHERE users.name > ? + ['m'] + COMMIT + {stop} + +## Further Reference {@name=reference} + +The best place to get every possible name you can use in constructed SQL is the [Generated Documentation](rel:docstrings_sqlalchemy.sql.expression). 
+ +Table Metadata Reference: [metadata](rel:metadata) + +Engine/Connection/Execution Reference: [dbengine](rel:dbengine) + +SQL Types: [types](rel:types) + + diff --git a/doc/build/content/threadlocal.txt b/doc/build/content/threadlocal.txt deleted file mode 100644 index 20ad270aef..0000000000 --- a/doc/build/content/threadlocal.txt +++ /dev/null @@ -1,2 +0,0 @@ -The threadlocal mod {@name=threadlocal} -============ diff --git a/doc/build/content/tutorial.txt b/doc/build/content/tutorial.txt index d2077043df..3483a5be29 100644 --- a/doc/build/content/tutorial.txt +++ b/doc/build/content/tutorial.txt @@ -43,6 +43,14 @@ Note that the SQLite library download is not required with Windows, as the Windo Getting Started {@name=gettingstarted} -------------------------- +### Checking the Version + +**Note: This tutorial is oriented towards version 0.4 of SQLAlchemy. ** Check the version of SQLAlchemy you have installed via: + + {python} + >>> import sqlalchemy + >>> sqlalchemy.__version__ # doctest: +SKIP + 0.4.0 ### Imports @@ -51,7 +59,7 @@ To start connecting to databases and begin issuing queries, we want to import th {python} >>> from sqlalchemy import * -Note that importing using the `*` operator pulls all the names from `sqlalchemy` into the local module namespace, which in a real application can produce name conflicts. Therefore its recommended in practice to either import the individual symbols desired (i.e. `from sqlalchemy import Table, Column`) or to import under a distinct namespace (i.e. `import sqlalchemy as sa`). +Note that importing using the `*` operator pulls all the names from `sqlalchemy` into the local module namespace, which in a real application can produce name conflicts. Therefore it's recommended in practice to either import the individual symbols desired (i.e. `from sqlalchemy import Table, Column`) or to import under a distinct namespace (i.e. `import sqlalchemy as sa`). ### Connecting to the Database @@ -60,7 +68,7 @@ After our imports, the next thing we need is a handle to the desired database, r {python} >>> db = create_engine('sqlite:///tutorial.db') -Technically, the above statement did not make an actual connection to the sqlite database just yet. As soon as we begine working with the engine, it will start creating connections. In the case of SQLite, the `tutorial.db` file will actually be created at the moment it is first used, if the file does not exist already. +Technically, the above statement did not make an actual connection to the SQLite database just yet. As soon as we begin working with the engine, it will start creating connections. In the case of SQLite, the `tutorial.db` file will actually be created at the moment it is first used, if the file does not exist already. For full information on creating database engines, including those for SQLite and others, see [dbengine](rel:dbengine). @@ -73,7 +81,7 @@ A central concept of SQLAlchemy is that it actually contains two distinct areas The Object Relational Mapper (ORM) is a set of tools completely distinct from the SQL Construction Language which serve the purpose of mapping Python object instances into database rows, providing a rich selection interface with which to retrieve instances from tables as well as a comprehensive solution to persisting changes on those instances back into the database. When working with the ORM, its underlying workings as well as its public API make extensive use of the SQL Construction Language, however the general theory of operation is slightly different. 
Instead of working with database rows directly, you work with your own user-defined classes and object instances. Additionally, the method of issuing queries to the database is different, as the ORM handles the job of generating most of the SQL required, and instead requires more information about what kind of class instances you'd like to load and where you'd like to put them. -Where SA is somewhat unique, more powerful, and slightly more complicated is that the two areas of functionality can be mixed together in many ways. A key strategy to working with SA effectively is to have a solid awareness of these two distinct toolsets, and which concepts of SA belong to each - even some publications have confused the SQL Construction Language with the ORM. The key difference between the two is that when you're working with cursor-like result sets its the SQL Construction Language, and when working with collections of your own class instances its the Object Relational Mapper. +Where SA is somewhat unique, more powerful, and slightly more complicated is that the two areas of functionality can be mixed together in many ways. A key strategy to working with SA effectively is to have a solid awareness of these two distinct toolsets, and which concepts of SA belong to each - even some publications have confused the SQL Construction Language with the ORM. The key difference between the two is that when you're working with cursor-like result sets it's the SQL Construction Language, and when working with collections of your own class instances it's the Object Relational Mapper. This tutorial will first focus on the basic configuration that is common to using both the SQL Construction Language as well as the ORM, which is to declare information about your database called **table metadata**. This will be followed by some constructed SQL examples, and then into usage of the ORM utilizing the same data we established in the SQL construction examples. @@ -290,7 +298,7 @@ To start, we will import the names necessary to use SQLAlchemy's ORM, again usin {python} >>> from sqlalchemy.orm import * -It should be noted that the above step is technically not needed when working with the 0.3 series of SQLAlchemy; all symbols from the `orm` package are also included in the `sqlalchemy` package. However, a future release (most likely the 0.4 series) will make the separate `orm` import required in order to use the object relational mapper, so its a good practice for now. +It should be noted that the above step is technically not needed when working with the 0.3 series of SQLAlchemy; all symbols from the `orm` package are also included in the `sqlalchemy` package. However, a future release (most likely the 0.4 series) will make the separate `orm` import required in order to use the object relational mapper, so it's a good practice for now. ### Creating a Mapper {@name=mapper} @@ -360,7 +368,7 @@ Notice that our `User` class has a special attribute `c` attached to it. This ' ### Making Changes {@name=changes} -With a little experience in loading objects, lets see what its like to make changes. First, lets create a new user "Ed". We do this by just constructing the new object. Then, we just add it to the session: +With a little experience in loading objects, lets see what it's like to make changes. First, lets create a new user "Ed". We do this by just constructing the new object. 
Then, we just add it to the session: {python} >>> ed = User() @@ -452,7 +460,7 @@ The `relation()` function takes either a class or a Mapper as its first argument The order in which the mapping definitions for `User` and `Address` is created is *not significant*. When the `mapper()` function is called, it creates an *uncompiled* mapping record corresponding to the given class/table combination. When the mappers are first used, the entire collection of mappers created up until that point will be compiled, which involves the establishment of class instrumentation as well as the resolution of all mapping relationships. -Lets try out this new mapping configuration, and see what we get for the email addresses already in the database. Since we have made a new mapping configuration, its best that we clear out our `Session`, which is currently holding onto every `User` object we have already loaded: +Lets try out this new mapping configuration, and see what we get for the email addresses already in the database. Since we have made a new mapping configuration, it's best that we clear out our `Session`, which is currently holding onto every `User` object we have already loaded: {python} >>> session.clear() @@ -489,7 +497,7 @@ Main documentation for using mappers: [datamapping](rel:datamapping) ### Transactions -You may have noticed from the example above that when we say `session.flush()`, SQLAlchemy indicates the names `BEGIN` and `COMMIT` to indicate a transaction with the database. The `flush()` method, since it may execute many statements in a row, will automatically use a transaction in order to execute these instructions. But what if we want to use `flush()` inside of a larger transaction? The easiest way is to use a "transactional" session; that is, when the session is created, you're automatically in a transaction which you can commit or rollback at any time. As a bonus, it offers the ability to call `flush()` for you, whenever a query is issued; that way whatever changes you've made can be returned right back (and since its all in a transaction, nothing gets committed until you tell it so). +You may have noticed from the example above that when we say `session.flush()`, SQLAlchemy indicates the names `BEGIN` and `COMMIT` to indicate a transaction with the database. The `flush()` method, since it may execute many statements in a row, will automatically use a transaction in order to execute these instructions. But what if we want to use `flush()` inside of a larger transaction? The easiest way is to use a "transactional" session; that is, when the session is created, you're automatically in a transaction which you can commit or rollback at any time. As a bonus, it offers the ability to call `flush()` for you, whenever a query is issued; that way whatever changes you've made can be returned right back (and since it's all in a transaction, nothing gets committed until you tell it so). Below, we create a session with `autoflush=True`, which implies that it's transactional. We can query for things as soon as they are created without the need for calling `flush()`. At the end, we call `commit()` to persist everything permanently. @@ -497,7 +505,7 @@ Below, we create a session with `autoflush=True`, which implies that it's transa >>> metadata.bind.echo = False >>> session = create_session(autoflush=True) >>> (ed, harry, mary) = session.query(User).filter( - ... User.c.user_name.in_('Ed', 'Harry', 'Mary') + ... User.c.user_name.in_(['Ed', 'Harry', 'Mary']) ... 
).order_by(User.c.user_name).all() # doctest: +NORMALIZE_WHITESPACE >>> del mary.addresses[1] >>> harry_address = Address('harry2@gmail.com') diff --git a/doc/build/content/types.txt b/doc/build/content/types.txt index 4abe508df4..f504f297a1 100644 --- a/doc/build/content/types.txt +++ b/doc/build/content/types.txt @@ -5,49 +5,67 @@ The package `sqlalchemy.types` defines the datatype identifiers which may be use ### Built-in Types {@name=standard} -SQLAlchemy comes with a set of standard generic datatypes, which are defined as classes. +SQLAlchemy comes with a set of standard generic datatypes, which are defined as classes. Types are usually used when defining tables, and can be left as a class or instantiated, for example: -The standard set of generic types are: + {python} + mytable = Table('mytable', metadata, + Column('myid', Integer, primary_key=True), + Column('data', String(30)), + Column('info', Unicode(100)), + Column('value', Numeric(7,4)) + ) - {python title="package sqlalchemy.types"} - class String(TypeEngine): - def __init__(self, length=None) - - class Integer(TypeEngine) - - class SmallInteger(Integer) - - class Numeric(TypeEngine): - def __init__(self, precision=10, length=2) - - class Float(Numeric): - def __init__(self, precision=10) - - # DateTime, Date and Time types deal with datetime objects from the Python datetime module - class DateTime(TypeEngine) - - class Date(TypeEngine) - - class Time(TypeEngine) - - class Binary(TypeEngine): - def __init__(self, length=None) - - class Boolean(TypeEngine) - - # converts unicode strings to raw bytes - # as bind params, raw bytes to unicode as - # rowset values, using the unicode encoding - # setting on the engine (defaults to 'utf-8') - class Unicode(TypeDecorator): - impl = String - - # uses the pickle protocol to serialize data - # in/out of Binary columns - class PickleType(TypeDecorator): - impl = Binary +Following is a rundown of the standard types. + +#### String + +This type is the base type for all string and character types, such as `Unicode`, `TEXT`, `CLOB`, etc. By default it generates a VARCHAR in DDL. It includes an argument `length`, which indicates the length in characters of the type, as well as `convert_unicode` and `assert_unicode`, described below. `length` will be used as the length argument when generating DDL. If `length` is omitted, the `String` type resolves into the `TEXT` type. + +`convert_unicode=True` indicates that incoming strings, if they are Python `unicode` strings, will be encoded into a raw bytestring using the `encoding` attribute of the dialect (defaults to `utf-8`). Similarly, raw bytestrings coming back from the database will be decoded into `unicode` objects on the way back. + +`assert_unicode` is set to `None` by default. When `True`, it indicates that incoming bind parameters will be checked to ensure they are in fact `unicode` objects, else an error is raised. A value of `'warn'` instead raises a warning. Setting it to `None` indicates that the dialect-level `convert_unicode` setting should take effect, whereas setting it to `False` disables it unconditionally (this flag is new as of version 0.4.2). + +Both `convert_unicode` and `assert_unicode` may be set at the engine level as flags to `create_engine()`. + +#### Unicode + +The `Unicode` type is shorthand for `String` with `convert_unicode=True` and `assert_unicode='warn'`. When writing a Unicode-aware application, it is strongly recommended that this type is used, and that only Unicode strings are used in the application. 
By "Unicode string" we mean a string with a u prefix, i.e. `u'hello'`. Otherwise, particularly when using the ORM, data will be converted to Unicode when it returns from the database, but data which was generated locally will not be in Unicode format, which can create confusion. + +#### Text / UnicodeText + +These are the "unbounded" versions of `String` and `Unicode`. They have no "length" parameter, and generate a column type of TEXT or CLOB. + +#### Numeric + +Numeric types return `decimal.Decimal` objects by default. The flag `asdecimal=False` may be specified, which enables the type to pass data straight through. Numeric also takes "precision" and "scale" arguments which are used when CREATE TABLE is issued. + +#### Float + +Float types return Python floats. Float also takes a "precision" argument which is used when CREATE TABLE is issued. + +#### DateTime/Date/Time + +Date and time types return objects from the Python `datetime` module. Most DBAPIs have built-in support for the datetime module, with the noted exception of SQLite. In the case of SQLite, date and time types are stored as strings which are then converted back to datetime objects when rows are returned. -More specific subclasses of these types are available, which various database engines may choose to implement specifically, allowing finer grained control over types: +#### Interval + +The Interval type deals with `datetime.timedelta` objects. In Postgres, the native INTERVAL type is used; for others, the value is stored as a date which is relative to the "epoch" (Jan. 1, 1970). + +#### Binary + +The Binary type generates BLOB or BYTEA when tables are created, and also converts incoming values using the `Binary` callable provided by each DBAPI. + +#### Boolean + +Boolean typically uses BOOLEAN or SMALLINT on the CREATE TABLE side, and returns Python `True` or `False`. + +#### PickleType + +PickleType builds upon the Binary type to apply Python's `pickle.dumps()` to incoming objects, and `pickle.loads()` on the way out, allowing any pickleable Python object to be stored as a serialized binary field. + +#### SQL-Specific Types {@name=sqlspecific} + +These are subclasses of the generic types and include: 
You can make your own classes to perform these operations. To augment the behavior of a `TypeEngine` type, such as `String`, the `TypeDecorator` class is used: +User-defined types can be created which can augment the bind parameter and result processing capabilities of the built in types. This is usually achieved using the `TypeDecorator` class, which "decorates" the behavior of any existing type. As of version 0.4.2, the new `process_bind_param()` and `process_result_value()` methods should be used: {python} import sqlalchemy.types as types class MyType(types.TypeDecorator): - """basic type that decorates String, prefixes values with "PREFIX:" on + """a type that decorates Unicode, prefixes values with "PREFIX:" on the way in and strips it off on the way out.""" - impl = types.String - def convert_bind_param(self, value, engine): + + impl = types.Unicode + + def process_bind_param(self, value, engine): return "PREFIX:" + value - def convert_result_value(self, value, engine): - return value[7:] -The `PickleType` class is an instance of `TypeDecorator` already and can be subclassed directly. + def process_result_value(self, value, engine): + return value[7:] + + def copy(self): + return MyType(self.impl.length) + +Note that the "old" way to process bind parameters and result values, the `convert_bind_param()` and `convert_result_value()` methods, are still available. The downside of these is that when using a type which already processes data such as the `Unicode` type, you need to call the superclass version of these methods directly. Using `process_bind_param()` and `process_result_value()`, user-defined code can return and receive the desired Python data directly. + +As of version 0.4.2, `TypeDecorator` should generally be used for any user-defined type which redefines the behavior of another type, including other `TypeDecorator` subclasses such as `PickleType`, and the new `process_...()` methods described above should be used. To build a type object from scratch, which will not have a corresponding database-specific implementation, subclass `TypeEngine`: @@ -126,14 +140,17 @@ To build a type object from scratch, which will not have a corresponding databas class MyType(types.TypeEngine): def __init__(self, precision = 8): self.precision = precision + def get_col_spec(self): return "MYTYPE(%s)" % self.precision + def convert_bind_param(self, value, engine): return value + def convert_result_value(self, value, engine): return value -Once you make your type, its immediately useable: +Once you make your type, it's immediately useable: {python} table = Table('foo', meta, @@ -141,4 +158,4 @@ Once you make your type, its immediately useable: Column('data', MyType(16)) ) - \ No newline at end of file + diff --git a/doc/build/content/unitofwork.txt b/doc/build/content/unitofwork.txt deleted file mode 100644 index ef01189011..0000000000 --- a/doc/build/content/unitofwork.txt +++ /dev/null @@ -1,493 +0,0 @@ -[alpha_api]: javascript:alphaApi() -[alpha_implementation]: javascript:alphaImplementation() - -Session / Unit of Work {@name=unitofwork} -============ - -### Overview {@name=overview} - -The concept behind Unit of Work is to track modifications to a field of objects, and then be able to flush those changes to the database in a single operation. 
Theres a lot of advantages to this, including that your application doesn't need to worry about individual save operations on objects, nor about the required order for those operations, nor about excessive repeated calls to save operations that would be more efficiently aggregated into one step. It also simplifies database transactions, providing a neat package with which to insert into the traditional database begin/commit phase. - -SQLAlchemy's unit of work includes these functions: - -* The ability to monitor scalar and list attributes on object instances, as well as object creates. This is handled via the attributes package. -* The ability to maintain and process a list of modified objects, and based on the relationships set up by the mappers for those objects as well as the foreign key relationships of the underlying tables, figure out the proper order of operations so that referential integrity is maintained, and also so that on-the-fly values such as newly created primary keys can be propigated to dependent objects that need them before they are saved. The central algorithm for this is the *topological sort*. -* The ability to define custom functionality that occurs within the unit-of-work flush phase, such as "before insert", "after insert", etc. This is accomplished via MapperExtension. -* an Identity Map, which is a dictionary storing the one and only instance of an object for a particular table/primary key combination. This allows many parts of an application to get a handle to a particular object without any chance of modifications going to two different places. -* The sole interface to the unit of work is provided via the `Session` object. Transactional capability is included. - -### Object States {@name=states} - -When dealing with mapped instances with regards to Sessions, an instance may be *attached* or *unattached* to a particular Session. An instance also may or may not correspond to an actual row in the database. These conditions break up into four distinct states: - -* *Transient* - a transient instance exists within memory only and is not associated with any Session. It also has no database identity and does not have a corresponding record in the database. When a new instance of a class is constructed, and no default session context exists with which to automatically attach the new instance, it is a transient instance. The instance can then be saved to a particular session in which case it becomes a *pending* instance. If a default session context exists, new instances are added to that Session by default and therefore become *pending* instances immediately. - -* *Pending* - a pending instance is a Session-attached object that has not yet been assigned a database identity. When the Session is flushed (i.e. changes are persisted to the database), a pending instance becomes persistent. - -* *Persistent* - a persistent instance has a database identity and a corresponding record in the database, and is also associated with a particular Session. By "database identity" we mean the object is associated with a table or relational concept in the database combined with a particular primary key in that table. Objects that are loaded by SQLAlchemy in the context of a particular session are automatically considered persistent, as are formerly pending instances which have been subject to a session `flush()`. - -* *Detached* - a detached instance is an instance which has a database identity and corresponding row in the database, but is not attached to any Session. 
This occurs when an instance has been removed from a Session, either because the session itself was cleared or closed, or the instance was explicitly removed from the Session. The object can be re-attached to a session in which case it becomes Persistent again; any un-persisted changes that exist on the instance, whether they occurred during its previous persistent state or during its detached state will be detected and maintained by the new session. Detached instances are useful when an application needs to represent a long-running operation across multiple Sessions, needs to store an object in a serialized state and then restore it later (such as within an HTTP "session" object), or in some cases where code needs to load instances locally which will later be associated with some other Session. - -### Acquiring a Session {@name=getting} - -A new Session object is constructed via the `create_session()` function: - - {python} - session = create_session() - -A common option used with `create_session()` is to specify a specific `Engine` or `Connection` to be used for all operations performed by this Session: - - {python} - # create an engine - e = create_engine('postgres://some/url') - - # create a Session that will use this engine for all operations. - # it will open and close Connections as needed. - session = create_session(bind=e) - - # open a Connection - conn = e.connect() - - # create a Session that will use this specific Connection for all operations - session = create_session(bind=conn) - - -The session to which an object is attached can be acquired via the `object_session()` function, which returns the appropriate `Session` if the object is pending or persistent, or `None` if the object is transient or detached: - - {python} - session = object_session(obj) - -Session Facts: - - * the Session object is **not threadsafe**. For thread-local management of Sessions, the recommended approach is to use the [plugins_sessioncontext](rel:plugins_sessioncontext) extension module. - -We will now cover some of the key concepts used by Sessions and its underlying Unit of Work. - -### Introduction to the Identity Map {@name=identitymap} - -A primary concept of the Session's underlying Unit of Work is that it is keeps track of all persistent instances; recall that a persistent instance has a database identity and is attached to a Session. In particular, the Unit of Work must ensure that only *one* copy of a particular persistent instance exists within the Session at any given time. The UOW accomplishes this task using a dictionary known as an *Identity Map*. - -When a `Query` is used to issue `select` or `get` requests to the database, it will in nearly all cases result in an actual SQL execution to the database, and a corresponding traversal of rows received from that execution. However, when the underlying mapper actually *creates* objects corresponding to the result set rows it receives, it will check the session's identity map first before instantating a new object, and return the same instance already present in the identity map if it already exists, essentially *ignoring* the object state represented by that row. There are several ways to override this behavior and truly refresh an already-loaded instance which are described later, but the main idea is that once your instance is loaded into a particular Session, it will *never change* its state without your explicit approval, regardless of what the database says about it. 
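This guarantee can be illustrated with a minimal sketch, assuming a mapped `User` class with `name` and `fullname` attributes and the `create_session()` API shown below: re-selecting a row whose instance is already present in the Session returns that same instance, and any pending in-memory changes are left intact.

    {python}
    sess = create_session()

    # load an instance and change it in memory only - no flush() yet
    jack = sess.query(User).filter_by(name='jack').first()
    jack.fullname = 'Jack Jones'

    # a second query for the same row emits SQL, but the row's current
    # state is ignored; the already-present instance is returned as-is
    samejack = sess.query(User).filter_by(name='jack').first()
    assert samejack is jack
    assert samejack.fullname == 'Jack Jones'
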
- -For example; below, two separate calls to load an instance with database identity "15" are issued, and the results assigned to two separate variables. However, since the same `Session` was used, the two instances are the same instance: - - {python} - mymapper = mapper(MyClass, mytable) - - session = create_session() - obj1 = session.query(MyClass).filter(mytable.c.id==15).first() - obj2 = session.query(MyClass).filter(mytable.c.id==15).first() - - >>> obj1 is obj2 - True - -The Identity Map is an instance of `dict` by default. As an option, you can specify the flag `weak_identity_map=True` to the `create_session` function so that it will use a `weakref.WeakValueDictionary`, so that when an in-memory object falls out of scope, it will be removed automatically, thereby providing some automatic management of memory. However, this may not be instant if there are circular references upon the object. To guarantee that an instance is removed from the identity map before removing references to it, use the `expunge()` method, described later, to remove it. Additionally, note that an object that has changes marked on it (i.e. "dirty") can still fall out of scope when using `weak_identity_map`. - -The Session supports an iterator interface in order to see all objects in the identity map: - - {python} - for obj in session: - print obj - -As well as `__contains__()`: - - {python} - if obj in session: - print "Object is present" - -The identity map itself is accessible via the `identity_map` accessor: - - {python} - >>> session.identity_map.values() - [<__main__.User object at 0x712630>, <__main__.Address object at 0x712a70>] - -The identity of each object instance is available via the `_instance_key` property attached to each object instance, and is a tuple consisting of the object's class and an additional tuple of primary key values, in the order that they appear within the table definition: - - {python} - >>> obj._instance_key - (, (7,)) - -At the moment that an object is assigned this key within a `flush()` operation, it is also added to the session's identity map. - -The `get()` method on `Query`, which retrieves an object based on primary key identity, also checks in the Session's identity map first to save a database round-trip if possible. In the case of an object lazy-loading a single child object, the `get()` method is used as well, so scalar-based lazy loads may in some cases not query the database; this is particularly important for backreference relationships as it can save a lot of queries. - -### Whats Changed ? {@name=changed} - -The next concept is that in addition to the `Session` storing a record of all objects loaded or saved, it also stores lists of all *newly created* (i.e. pending) objects and lists of all persistent objects that have been marked as *deleted*. These lists are used when a `flush()` call is issued to save all changes. During a flush operation, it also scans its list of persistent instances for changes which are marked as dirty. 
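To make this bookkeeping concrete, here is a minimal sketch (again assuming a mapped `User` class accepting keyword arguments, and existing rows for "ed" and "mary") of how a pending save, a modification to a persistent instance, and a deletion all accumulate in the Session until a single `flush()` writes them out:

    {python}
    sess = create_session()

    fred = User(name='fred')
    sess.save(fred)                   # newly created - pending until flushed

    ed = sess.query(User).filter_by(name='ed').first()
    ed.fullname = 'Ed Jones'          # persistent instance, now considered dirty

    mary = sess.query(User).filter_by(name='mary').first()
    sess.delete(mary)                 # persistent instance, marked as deleted

    sess.flush()                      # one INSERT, one UPDATE and one DELETE are issued
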
- -These records are all tracked by collection functions that are also viewable off the `Session` as properties: - - {python} - # pending objects recently added to the Session - session.new - - # persistent objects which currently have changes detected - # (this collection is now created on the fly each time the property is called) - session.dirty - - # persistent objects that have been marked as deleted via session.delete(obj) - session.deleted - -Note that if a session is created with the `weak_identity_map` flag, an item which is marked as "dirty" will be silently removed from the session if the item falls out of scope in the user application. This is because the unit of work does not look for "dirty" changes except for within a flush operation (or any time the session.dirty collection is accessed). - -As for objects inside of `new` and `deleted`, if you abandon all references to new or modified objects within a session, *they are still present* in either of those two lists, and will be saved on the next flush operation, unless they are removed from the Session explicitly (more on that later). - -### The Session API {@name=api} - -#### query() {@name=query} - -The `query()` function takes one or more classes and/or mappers, along with an optional `entity_name` parameter, and returns a new `Query` object which will issue mapper queries within the context of this Session. For each mapper is passed, the Query uses that mapper. For each class, the Query will locate the primary mapper for the class using `class_mapper()`. - - {python} - # query from a class - session.query(User).filter_by(name='ed').all() - - # query with multiple classes, returns tuples - session.query(User, Address).join('addresses').filter_by(name='ed').all() - - # query from a mapper - query = session.query(usermapper) - x = query.get(1) - - # query from a class mapped with entity name 'alt_users' - q = session.query(User, entity_name='alt_users') - y = q.options(eagerload('orders')).all() - -`entity_name` is an optional keyword argument sent with a class object, in order to further qualify which primary mapper to be used; this only applies if there was a `Mapper` created with that particular class/entity name combination, else an exception is raised. All of the methods on Session which take a class or mapper argument also take the `entity_name` argument, so that a given class can be properly matched to the desired primary mapper. - -All instances retrieved by the returned `Query` object will be stored as persistent instances within the originating `Session`. - -#### get() {@name=get} - -Given a class or mapper, a scalar or tuple-based identity, and an optional `entity_name` keyword argument, creates a `Query` corresponding to the given mapper or class/entity_name combination, and calls the `get()` method with the given identity value. If the object already exists within this Session, it is simply returned, else it is queried from the database. If the instance is not found, the method returns `None`. - - {python} - # get Employer primary key 5 - employer = session.get(Employer, 5) - - # get Report composite primary key 7,12, using mapper 'report_mapper_b' - report = session.get(Report, (7,12), entity_name='report_mapper_b') - - -#### load() {@name=load} - -load() is similar to get() except it will raise an exception if the instance does not exist in the database. 
It will also load the object's data from the database in all cases, and **overwrite** all changes on the object if it already exists in the session with the latest data from the database. - - {python} - # load Employer primary key 5 - employer = session.load(Employer, 5) - - # load Report composite primary key 7,12, using mapper 'report_mapper_b' - report = session.load(Report, (7,12), entity_name='report_mapper_b') - -#### save() {@name=save} - -save() is called with a single transient (unsaved, unattached) instance as an argument, which is then added to the Session and becomes pending. When the session is next `flush`ed, the instance will be saved to the database uponwhich it becomes persistent (saved, attached). If the given instance is not transient, meaning it is either attached to an existing Session or it has a database identity, an exception is raised. - - {python} - user1 = User(name='user1') - user2 = User(name='user2') - session.save(user1) - session.save(user2) - - session.flush() # write changes to the database - -save() is called automatically for new instances by the classes' associated mapper, if a default Session context is in effect (such as a thread-local session), which means that newly created instances automatically become pending. If there is no default session available, then the instance remains transient (unattached) until it is explicitly added to a Session via the save() method. - -A transient instance also can be automatically `save`ed if it is associated with a parent object which specifies `save-update` within its `cascade` rules, and that parent is already attached or becomes attached to a Session. For more information on `cascade`, see the next section. - -The `save_or_update()` method, covered later, is a convenience method which will call the `save()` or `update()` methods appropriately dependening on whether or not the instance has a database identity (but the instance still must be unattached). - -#### flush() {@name=flush} - -This is the main gateway to what the Unit of Work does best, which is save everything ! It should be clear by now what a flush looks like: - - {python} - session.flush() - -It also can be called with a list of objects; in this form, the flush operation will be limited only to the objects specified in the list, as well as any child objects within `private` relationships for a delete operation: - - {python} - # saves only user1 and address2. all other modified - # objects remain present in the session. - session.flush([user1, address2]) - -This second form of flush should be used carefully as it will not necessarily locate other dependent objects within the session, whose database representation may have foreign constraint relationships with the objects being operated upon. - -Theres also a way to have `flush()` called automatically before each query; this is called "autoflush" and is described below. - -##### Notes on Flush {@name=whatis} - -A common misconception about the `flush()` operation is that once performed, the newly persisted instances will automatically have related objects attached to them, based on the values of primary key identities that have been assigned to the instances before they were persisted. An example would be, you create a new `Address` object, set `address.user_id` to 5, and then `flush()` the session. The erroneous assumption would be that there is now a `User` object of identity "5" attached to the `Address` object, but in fact this is not the case. 
If you were to `refresh()` the `Address`, invalidating its current state and re-loading, *then* it would have the appropriate `User` object present. - -This misunderstanding is related to the observed behavior of backreferences ([datamapping_relations_backreferences](rel:datamapping_relations_backreferences)), which automatically associates an instance "A" with another instance "B", in response to the manual association of instance "B" to instance "A" by the user. The backreference operation occurs completely externally to the `flush()` operation, and is pretty much the only example of a SQLAlchemy feature that manipulates the relationships of persistent objects. - -The primary guideline for dealing with `flush()` is, the developer is responsible for maintaining in-memory objects and their relationships to each other, the unit of work is responsible for maintaining the database representation of the in-memory objects. The typical pattern is that the manipulation of objects *is* the way that changes get communicated to the unit of work, so that when the flush occurs, the objects are already in their correct in-memory representation and problems dont arise. The manipulation of identifier attributes like integer key values as well as deletes in particular are a frequent source of confusion. - -#### close() {@name=close} - -This method first calls `clear()`, removing all objects from this `Session`, and then ensures that any transactional resources are closed. - -#### delete() {@name=delete} - -The `delete` method places an instance into the Unit of Work's list of objects to be marked as deleted: - - {python} - # mark two objects to be deleted - session.delete(obj1) - session.delete(obj2) - - # flush - session.flush() - -The delete operation will have an effect on instances that are attached to the deleted instance according to the `cascade` style of the relationship; cascade rules are described further in the following section. By default, associated instances may need to be updated in the database to reflect that they no longer are associated with the parent object, before the parent is deleted. If the relationship specifies `cascade="delete"`, then the associated instance will also be deleted upon flush, assuming it is still attached to the parent. If the relationship additionally includes the `delete-orphan` cascade style, the associated instance will be deleted if it is still attached to the parent, or is unattached to any other parent. - -The `delete()` operation has no relationship to the in-memory status of the instance, including usage of the `del` Python statement. An instance marked as deleted and flushed will still exist within memory until references to it are freed; similarly, removing an instance from memory via the `del` statement will have no effect, since the persistent instance will still be referenced by its Session. Obviously, if the instance is removed from the Session and then totally dereferenced, it will no longer exist in memory, but also won't exist in any Session and is therefore not deleted from the database. - -Note that the "in-memory status" of an instance also refers to its presence in any other collection. SQLAlchemy does not track the collections to which an instance is a member, and will not remove an instance from its parent collections that were not directly involved in a deletion operation. The operational and memory overhead implied by this would be too great (such as, if an object belonged to hundreds of collections). 
This means if an object `A` is attached to both an object `B` and an object `C`, if you `delete()` `A` and flush, `A` still remains attached to both `B` and `C` in a deleted state and must be removed by the application. Similarly, if a delete on `B` cascades to `A`, this **does not** affect `A` still being present on `C` - again it must be manually removed. - - -#### clear() {@name=clear} - -This method detaches all instances from the Session, sending them to the detached or transient state as applicable, and replaces the underlying UnitOfWork with a new one. - - {python} - session.clear() - -The `clear()` method is particularly useful with a "default context" session such as a thread-local session, which can stay attached to the current thread to handle a new field of objects without having to re-attach a new Session. - -#### refresh() / expire() {@name=refreshexpire} - -To assist with the Unit of Work's "sticky" behavior, individual objects can have all of their attributes immediately re-loaded from the database, or marked as "expired" which will cause a re-load to occur upon the next access of any of the object's mapped attributes. This includes all relationships, so lazy-loaders will be re-initialized, eager relationships will be repopulated. Any changes marked on the object are discarded: - - {python} - # immediately re-load attributes on obj1, obj2 - session.refresh(obj1) - session.refresh(obj2) - - # expire objects obj1, obj2, attributes will be reloaded - # on the next access: - session.expire(obj1) - session.expire(obj2) - -#### expunge() {@name=expunge} - -Expunge removes an object from the Session, sending persistent instances to the detached state, and pending instances to the transient state: - - {python} - session.expunge(obj1) - -Use `expunge` when youd like to remove an object altogether from memory, such as before calling `del` on it, which will prevent any "ghost" operations occuring when the session is flushed. - -#### bind\_mapper() / bind\_table() {@name=bind} - -Both of these methods receive two arguments; in the case of `bind_mapper()`, it is a `Mapper` and an `Engine` or `Connection` instance; in the case of `bind_table()`, it is a `Table` instance or other `Selectable` (such as an `Alias`, `Select`, etc.), and an `Engine` or `Connection` instance. - - {python} - engine1 = create_engine('sqlite:///file1.db') - engine2 = create_engine('mysql://localhost') - - sqlite_conneciton = engine1.connect() - - sess = create_session() - - sess.bind_mapper(mymapper, sqlite_connection) # bind mymapper operations to a single SQLite connection - sess.bind_table(email_addresses_table, engine2) # bind operations with the email_addresses_table to mysql - -Normally, when a `Session` is created via `create_session()` with no arguments, the Session has no awareness of individual `Engines`, and when mappers use the `Session` to retrieve connections, the underlying `MetaData` each `Table` is associated with is expected to be "bound" to an `Engine`, else no engine can be located and an exception is raised. A second form of `create_session()` takes the argument `bind=engine_or_connection`, where all SQL operations performed by this `Session` use the single `Engine` or `Connection` (collectively known as a `Connectable`) passed to the constructor. 
With `bind_mapper()` and `bind_table()`, the operations of individual mapper and/or tables are bound to distinct engines or connections, thereby overriding not only the engine which may be "bound" to the underlying `MetaData`, but also the `Engine` or `Connection` which may have been passed to the `create_session()` function. Configurations which interact with multiple explicit database connections at one time must use either or both of these methods in order to associate `Session` operations with the appropriate connection resource. - -Binding a `Mapper` to a resource takes precedence over a `Table` bind, meaning if mapper A is associated with table B, and the Session binds mapper A to connection X and table B to connection Y, an operation with mapper A will use connection X, not connection Y. - -#### update() {@name=update} - -The update() method is used *only* with detached instances. A detached instance only exists if its `Session` was cleared or closed, or the instance was `expunge()`d from its session. `update()` will re-attach the detached instance with this Session, bringing it back to the persistent state, and allowing any changes on the instance to be saved when the `Session` is next `flush`ed. If the instance is already attached to an existing `Session`, an exception is raised. - -A detached instance also can be automatically `update`ed if it is associated with a parent object which specifies `save-update` within its `cascade` rules, and that parent is already attached or becomes attached to a Session. For more information on `cascade`, see the next section. - -The `save_or_update()` method is a convenience method which will call the `save()` or `update()` methods appropriately dependening on whether or not the instance has a database identity (but the instance still must be unattached). - -#### save\_or\_update() {@name=saveorupdate} - -This method is a combination of the `save()` and `update()` methods, which will examine the given instance for a database identity (i.e. if it is transient or detached), and will call the implementation of `save()` or `update()` as appropriate. Use `save_or_update()` to add unattached instances to a session when you're not sure if they were newly created or not. Like `save()` and `update()`, `save_or_update()` cascades along the `save-update` cascade indicator, described in the `cascade` section below. - -#### merge() {@name=merge} - -`merge()` is used to return the persistent version of an instance that is not attached to this Session. When passed an instance, if an instance with its database identity already exists within this Session, it is returned. If the instance does not exist in this Session, it is loaded from the database and then returned. - -A future version of `merge()` will also update the Session's instance with the state of the given instance (hence the name "merge"). - -This method is useful for bringing in objects which may have been restored from a serialization, such as those stored in an HTTP session: - - {python} - # deserialize an object - myobj = pickle.loads(mystring) - - # "merge" it. if the session already had this object in the - # identity map, then you get back the one from the current session. - myobj = session.merge(myobj) - -Note that `merge()` *does not* associate the given instance with the Session; it remains detached (or attached to whatever Session it was already attached to). - -### Cascade rules {@name=cascade} - -Mappers support the concept of configurable *cascade* behavior on `relation()`s. 
This behavior controls how the Session should treat the instances that have a parent-child relationship with another instance that is operated upon by the Session. Cascade is indicated as a comma-separated list of string keywords, with the possible values `all`, `delete`, `save-update`, `refresh-expire`, `merge`, `expunge`, and `delete-orphan`. - -Cascading is configured by setting the `cascade` keyword argument on a `relation()`: - - {python} - mapper(Order, order_table, properties={ - 'items' : relation(Item, items_table, cascade="all, delete-orphan"), - 'customer' : relation(User, users_table, user_orders_table, cascade="save-update"), - }) - -The above mapper specifies two relations, `items` and `customer`. The `items` relationship specifies "all, delete-orphan" as its `cascade` value, indicating that all `save`, `update`, `merge`, `expunge`, `refresh` `delete` and `expire` operations performed on a parent `Order` instance should also be performed on the child `Item` instances attached to it (`save` and `update` are cascaded using the `save_or_update()` method, so that the database identity of the instance doesn't matter). The `delete-orphan` cascade value additionally indicates that if an `Item` instance is no longer associated with an `Order`, it should also be deleted. The "all, delete-orphan" cascade argument allows a so-called *lifecycle* relationship between an `Order` and an `Item` object. - -The `customer` relationship specifies only the "save-update" cascade value, indicating most operations will not be cascaded from a parent `Order` instance to a child `User` instance, except for if the `Order` is attached with a particular session, either via the `save()`, `update()`, or `save-update()` method. - -Additionally, when a child item is attached to a parent item that specifies the "save-update" cascade value on the relationship, the child is automatically passed to `save_or_update()` (and the operation is further cascaded to the child item). - -Note that cascading doesn't do anything that isn't possible by manually calling Session methods on individual instances within a hierarchy, it merely automates common operations on a group of associated instances. - -The default value for `cascade` on `relation()`s is `save-update`, and the `private=True` keyword argument is a synonym for `cascade="all, delete-orphan"`. - -### Using Session Transactions {@name=transaction} - -The Session can manage transactions automatically, including across multiple engines. When the Session is in a transaction, as it receives requests to execute SQL statements, it adds each indivdual Connection/Engine encountered to its transactional state. At commit time, all unflushed data is flushed, and each individual transaction is committed. If the underlying databases support two-phase semantics, this may be used by the Session as well if two-phase transactions are enabled. - -The easiest way to use a Session with transactions is just to declare it as transactional. 
The session will remain in a transaction at all times: - - {python} - sess = create_session(transactional=True) - item1 = sess.query(Item).get(1) - item2 = sess.query(Item).get(2) - item1.foo = 'bar' - item2.bar = 'foo' - - # commit- will immediately go into a new transaction afterwards - sess.commit() - -Alternatively, a transaction can be begun explicitly using `begin()`: - - {python} - sess = create_session() - sess.begin() - try: - item1 = sess.query(Item).get(1) - item2 = sess.query(Item).get(2) - item1.foo = 'bar' - item2.bar = 'foo' - except: - sess.rollback() - raise - sess.commit() - -Session also supports Python 2.5's with statement so that the example above can be written as: - - {python} - sess = create_session() - with sess.begin(): - item1 = sess.query(Item).get(1) - item2 = sess.query(Item).get(2) - item1.foo = 'bar' - item2.bar = 'foo' - -For MySQL and Postgres (and soon Oracle), "nested" transactions can be accomplished which use SAVEPOINT behavior, via the `begin_nested()` method: - - {python} - sess = create_session() - sess.begin() - sess.save(u1) - sess.save(u2) - sess.flush() - - sess.begin_nested() # establish a savepoint - sess.save(u3) - sess.rollback() # rolls back u3, keeps u1 and u2 - - sess.commit() # commits u1 and u2 - -Finally, for MySQL, Postgres, and soon Oracle as well, the session can be instructed to use two-phase commit semantics using the flag `twophase=True`, which coordinates transactions across multiple databases: - - {python} - engine1 = create_engine('postgres://db1') - engine2 = create_engine('postgres://db2') - - sess = create_session(twophase=True, transactional=True) - - # bind User operations to engine 1 - sess.bind_mapper(User, engine1) - - # bind Account operations to engine 2 - sess.bind_mapper(Account, engine2) - - # .... work with accounts and users - - # commit. session will issue a flush to all DBs, and a prepare step to all DBs, - # before committing both transactions - sess.commit() - -#### AutoFlush {@name=autoflush} - -A transactional session can also conveniently issue `flush()` calls before each query. This allows you to immediately have DB access to whatever has been saved to the session. Creating the session with `autoflush=True` implies `transactional=True`: - - {python} - sess = create_session(autoflush=True) - u1 = User(name='jack') - sess.save(u1) - - # reload user1 - u2 = sess.query(User).filter_by(name='jack').one() - assert u2 is u1 - - # commit session, flushes whatever is remaining - sess.commit() - -#### Using SQL with Sessions and Transactions {@name=sql} - -SQL constructs and string statements can be executed via the `Session`. You'd want to do this normally when your `Session` is transactional and youd like your free-standing SQL statements to participate in the same transaction. - -The two ways to do this are to use the connection/execution services of the Session, or to have your Session participate in a regular SQL transaction. 
- -First, a Session thats associated with an Engine or Connection can execute statements immediately (whether or not its transactional): - - {python} - sess = create_session(bind=engine, transactional=True) - result = sess.execute("select * from table where id=:id", {'id':7}) - result2 = sess.execute(select([mytable], mytable.c.id==7)) - -To get at the current connection used by the session, which will be part of the current transaction if one is in progress, use `connection()`: - - connection = sess.connection() - -A second scenario is that of a Session which is not directly bound to a connectable. This session executes statements relative to a particular `Mapper`, since the mappers are bound to tables which are in turn bound to connectables via their `MetaData` (either the session or the mapped tables need to be bound). In this case, the Session can conceivably be associated with multiple databases through different mappers; so it wants you to send along a `mapper` argument, which can be any mapped class or mapper instance: - - {python} - sess = create_session(transactional=True) - result = sess.execute("select * from table where id=:id", {'id':7}, mapper=MyMappedClass) - result2 = sess.execute(select([mytable], mytable.c.id==7), mapper=MyMappedClass) - - connection = sess.connection(MyMappedClass) - -The third scenario is when you are using `Connection` and `Transaction` yourself, and want the `Session` to participate. This is easy, as you just bind the `Session` to the connection: - - {python} - conn = engine.connect() - trans = conn.begin() - sess = create_session(bind=conn) - # ... etc - trans.commit() - -It's safe to use a `Session` which is transactional or autoflushing, as well as to call `begin()`/`commit()` on the session too; the outermost Transaction object, the one we declared explicitly, controls the scope of the transaction. - -When using the `threadlocal` engine context, things are that much easier; the `Session` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it: - - {python} - engine = create_engine('postgres://mydb', strategy="threadlocal") - engine.begin() - - sess = create_session() # session takes place in the transaction like everyone else - - # ... 
go nuts - - engine.commit() # commit - diff --git a/doc/build/gen_docstrings.py b/doc/build/gen_docstrings.py index 346497d3ec..d9bad13841 100644 --- a/doc/build/gen_docstrings.py +++ b/doc/build/gen_docstrings.py @@ -2,7 +2,9 @@ from toc import TOCElement import docstring import re -from sqlalchemy import schema, types, ansisql, engine, sql, pool, orm, exceptions, databases +from sqlalchemy import schema, types, engine, sql, pool, orm, exceptions, databases, interfaces +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, strategies, threadlocal, url import sqlalchemy.orm.shard import sqlalchemy.ext.sessioncontext as sessioncontext import sqlalchemy.ext.selectresults as selectresults @@ -10,6 +12,7 @@ import sqlalchemy.ext.orderinglist as orderinglist import sqlalchemy.ext.associationproxy as associationproxy import sqlalchemy.ext.assignmapper as assignmapper import sqlalchemy.ext.sqlsoup as sqlsoup +import sqlalchemy.ext.declarative as declarative def make_doc(obj, classes=None, functions=None, **kwargs): """generate a docstring.ObjectDoc structure for an individual module, list of classes, and list of functions.""" @@ -20,31 +23,32 @@ def make_all_docs(): """generate a docstring.AbstractDoc structure.""" print "generating docstrings" objects = [ - make_doc(obj=sql,include_all_classes=True), - make_doc(obj=schema), + make_doc(obj=engine), + make_doc(obj=default), + make_doc(obj=strategies), + make_doc(obj=threadlocal), + make_doc(obj=url), + make_doc(obj=exceptions), + make_doc(obj=interfaces), make_doc(obj=pool), + make_doc(obj=schema), + #make_doc(obj=sql,include_all_classes=True), + make_doc(obj=compiler), + make_doc(obj=expression,include_all_classes=True), make_doc(obj=types), - make_doc(obj=engine), - make_doc(obj=engine.url), - make_doc(obj=engine.strategies), - make_doc(obj=engine.default), - make_doc(obj=engine.threadlocal), - make_doc(obj=ansisql), make_doc(obj=orm), make_doc(obj=orm.collections, classes=[orm.collections.collection, orm.collections.MappedCollection, orm.collections.CollectionAdapter]), make_doc(obj=orm.interfaces), - make_doc(obj=orm.mapperlib, classes=[orm.mapperlib.MapperExtension, orm.mapperlib.Mapper]), + make_doc(obj=orm.mapperlib, classes=[orm.mapperlib.Mapper]), make_doc(obj=orm.properties), - make_doc(obj=orm.query, classes=[orm.query.Query, orm.query.QueryContext, orm.query.SelectionContext]), - make_doc(obj=orm.session, classes=[orm.session.Session, orm.session.SessionTransaction]), + make_doc(obj=orm.query, classes=[orm.query.Query]), + make_doc(obj=orm.session, classes=[orm.session.Session, orm.session.SessionExtension]), make_doc(obj=orm.shard), - make_doc(obj=exceptions), - make_doc(obj=assignmapper), + make_doc(obj=declarative), make_doc(obj=associationproxy, classes=[associationproxy.AssociationProxy]), make_doc(obj=orderinglist, classes=[orderinglist.OrderingList]), - make_doc(obj=sessioncontext), make_doc(obj=sqlsoup), ] + [make_doc(getattr(__import__('sqlalchemy.databases.%s' % m).databases, m)) for m in databases.__all__] return objects @@ -52,7 +56,7 @@ def make_all_docs(): def create_docstring_toc(data, root): """given a docstring.AbstractDoc structure, create new TOCElement nodes corresponding to the elements and cross-reference them back to the doc structure.""" - root = TOCElement("docstrings", name="docstrings", description="Generated Documentation", parent=root, requires_paged=True) + root = TOCElement("docstrings", name="docstrings", description="API Documentation", parent=root, requires_paged=True) 
files = [] def create_obj_toc(obj, toc): if obj.isclass: diff --git a/doc/build/genhtml.py b/doc/build/genhtml.py index ddc2e8a926..e28f866095 100644 --- a/doc/build/genhtml.py +++ b/doc/build/genhtml.py @@ -1,8 +1,9 @@ #!/usr/bin/env python import sys,re,os,shutil +from os import path import cPickle as pickle -sys.path = ['../../lib', './lib/'] + sys.path +sys.path = ['../../lib', './lib'] + sys.path import sqlalchemy import gen_docstrings, read_markdown, toc @@ -14,13 +15,13 @@ import optparse files = [ 'index', 'documentation', - 'tutorial', + 'intro', + 'ormtutorial', + 'sqlexpression', + 'mappers', + 'session', 'dbengine', 'metadata', - 'sqlconstruction', - 'datamapping', - 'unitofwork', - 'adv_datamapping', 'types', 'pooling', 'plugins', @@ -31,10 +32,14 @@ post_files = [ 'copyright' ] +v = open(path.join(path.dirname(__file__), '..', '..', 'VERSION')) +VERSION = v.readline().strip() +v.close() + parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]") parser.add_option("--file", action="store", dest="file", help="only generate file ") parser.add_option("--docstrings", action="store_true", dest="docstrings", help="only generate docstrings") -parser.add_option("--version", action="store", dest="version", default=sqlalchemy.__version__, help="version string") +parser.add_option("--version", action="store", dest="version", default=VERSION, help="version string") (options, args) = parser.parse_args() if options.file: @@ -45,6 +50,7 @@ else: title='SQLAlchemy 0.4 Documentation' version = options.version + root = toc.TOCElement('', 'root', '', version=version, doctitle=title) shutil.copy('./content/index.html', './output/index.html') diff --git a/doc/build/lib/docstring.py b/doc/build/lib/docstring.py index f0aebe92ba..819296ccf1 100644 --- a/doc/build/lib/docstring.py +++ b/doc/build/lib/docstring.py @@ -34,7 +34,7 @@ class ObjectDoc(AbstractDoc): for x in objects if getattr(obj,x,None) is not None and (isinstance(getattr(obj,x), types.FunctionType)) - and not getattr(obj,x).__name__[0] == '_' + and not self._is_private_name(getattr(obj,x).__name__) ] if sort: functions.sort(lambda a, b: cmp(a.__name__, b.__name__)) @@ -43,7 +43,7 @@ class ObjectDoc(AbstractDoc): if getattr(obj,x,None) is not None and (isinstance(getattr(obj,x), types.TypeType) or isinstance(getattr(obj,x), types.ClassType)) - and (self.include_all_classes or not getattr(obj,x).__name__[0] == '_') + and (self.include_all_classes or not self._is_private_name(getattr(obj,x).__name__)) ] classes = list(set(classes)) if sort: @@ -53,14 +53,14 @@ class ObjectDoc(AbstractDoc): functions = ( [getattr(obj, x).im_func for x in obj.__dict__.keys() if isinstance(getattr(obj,x), types.MethodType) and - (getattr(obj, x).__name__ == '__init__' or not getattr(obj,x).__name__[0] == '_') + (getattr(obj, x).__name__ == '__init__' or not self._is_private_name(getattr(obj,x).__name__)) ] + [(x, getattr(obj, x)) for x in obj.__dict__.keys() if _is_property(getattr(obj,x)) and - not x[0] == '_' + not self._is_private_name(x) ] ) - functions.sort(lambda a, b: cmp(getattr(a, '__name__', None) or a[0], getattr(b, '__name__', None) or b[0] )) + functions.sort(_method_sort) if classes is None: classes = [] @@ -86,21 +86,32 @@ class ObjectDoc(AbstractDoc): self.doc = obj.__doc__ self.functions = [] - if not self.isclass and len(functions): + if not self.isclass: for func in functions: self.functions.append(FunctionDoc(func)) else: - if len(functions): - for func in functions: - if isinstance(func, types.FunctionType): - 
self.functions.append(FunctionDoc(func)) - elif isinstance(func, tuple): - self.functions.append(PropertyDoc(func[0], func[1])) + for func in functions: + if isinstance(func, types.FunctionType): + self.functions.append(MethodDoc(func, self)) + elif isinstance(func, tuple): + self.functions.append(PropertyDoc(func[0], func[1])) self.classes = [] for class_ in classes: self.classes.append(ObjectDoc(class_)) - + + def _is_private_name(self, name): + if name in ('__weakref__', '__repr__','__str__', '__unicode__', + '__getstate__', '__setstate__', '__reduce__', + '__reduce_ex__', '__hash__'): + return True + elif re.match(r'^__.*__$', name): + return False + elif name.startswith('_'): + return True + else: + return False + def _get_inherits(self): for item in self._inherits: if item[0] in self.allobjects: @@ -139,6 +150,12 @@ class FunctionDoc(AbstractDoc): def accept_visitor(self, visitor): visitor.visit_function(self) +class MethodDoc(FunctionDoc): + def __init__(self, func, owner): + super(MethodDoc, self).__init__(func) + if self.name == '__init__' and not self.doc: + self.doc = "Construct a new ``%s``." % owner.name + class PropertyDoc(AbstractDoc): def __init__(self, name, prop): super(PropertyDoc, self).__init__(prop) @@ -147,3 +164,18 @@ class PropertyDoc(AbstractDoc): self.link = name def accept_visitor(self, visitor): visitor.visit_property(self) + +def _method_sort(fna, fnb): + a = getattr(fna, '__name__', None) or fna[0] + b = getattr(fnb, '__name__', None) or fnb[0] + + if a == '__init__': return -1 + if b == '__init__': return 1 + + a_u = a.startswith('__') and a.endswith('__') + b_u = b.startswith('__') and b.endswith('__') + + if a_u and not b_u: return 1 + if b_u and not a_u: return -1 + + return cmp(a, b) diff --git a/doc/build/lib/highlight.py b/doc/build/lib/highlight.py index 1a838408b7..2fff3704ed 100644 --- a/doc/build/lib/highlight.py +++ b/doc/build/lib/highlight.py @@ -170,7 +170,7 @@ class PythonHighlighter(Highlighter): curc = t[2][1] if self.get_style(t[0], t[1]) != curstyle: - if len(tokens): + if tokens: self.colorize([(string.join(tokens, ''), curstyle)]) tokens = [] curstyle = self.get_style(t[0], t[1]) @@ -187,7 +187,7 @@ class PythonHighlighter(Highlighter): curl = t[3][0] # any remaining content to output, output it - if len(tokens): + if tokens: self.colorize([(string.join(tokens, ''), curstyle)]) if trailingspace: diff --git a/doc/build/lib/toc.py b/doc/build/lib/toc.py index dcad5d5c63..b629513456 100644 --- a/doc/build/lib/toc.py +++ b/doc/build/lib/toc.py @@ -44,7 +44,7 @@ class TOCElement(object): self.next = None self.children = [] if parent: - if len(parent.children): + if parent.children: self.previous = parent.children[-1] parent.children[-1].next = self parent.children.append(self) diff --git a/doc/build/read_markdown.py b/doc/build/read_markdown.py index c80589fc2b..53b5b12a85 100644 --- a/doc/build/read_markdown.py +++ b/doc/build/read_markdown.py @@ -25,7 +25,7 @@ def dump_tree(elem, stream): dump_mako_tag(elem, stream) else: if elem.tag != 'html': - if len(elem.attrib): + if elem.attrib: stream.write("<%s %s>" % (elem.tag, " ".join(["%s=%s" % (key, repr(val)) for key, val in elem.attrib.iteritems()]))) else: stream.write("<%s>" % elem.tag) @@ -35,7 +35,8 @@ def dump_tree(elem, stream): dump_tree(child, stream) if child.tail: stream.write(child.tail) - stream.write("" % elem.tag) + if elem.tag != 'html': + stream.write("" % elem.tag) def dump_mako_tag(elem, stream): tag = elem.tag[5:] @@ -143,18 +144,19 @@ def replace_pre_with_mako(tree): # 
syntax highlighter which uses the tokenize module text = re.sub(r'>>> ', r'">>>" ', text) - sqlre = re.compile(r'{sql}(.*?\n)((?:BEGIN|SELECT|INSERT|DELETE|UPDATE|CREATE|DROP|PRAGMA|DESCRIBE).*?)\n\s*(\n|$)', re.S) + sqlre = re.compile(r'{sql}(.*?)\n((?:PRAGMA|BEGIN|SELECT|INSERT|DELETE|ROLLBACK|COMMIT|UPDATE|CREATE|DROP|PRAGMA|DESCRIBE).*?)\n\s*((?:{stop})|\n|$)', re.S) if sqlre.search(text) is not None: use_sliders = False else: use_sliders = True - text = sqlre.sub(r"""${formatting.poplink()}\1\n<%call expr="formatting.codepopper()">\2\n\n""", text) + text = sqlre.sub(r"""${formatting.poplink()}\1<%call expr="formatting.codepopper()">\2""", text) - sqlre2 = re.compile(r'{opensql}(.*?\n)((?:BEGIN|SELECT|INSERT|DELETE|UPDATE|CREATE|DROP).*?)\n\s*(\n|$)', re.S) - text = sqlre2.sub(r"<%call expr='formatting.poppedcode()' >\1\n\2\n\n", text) + #sqlre2 = re.compile(r'{opensql}(.*?\n)((?:PRAGMA|BEGIN|SELECT|INSERT|DELETE|UPDATE|ROLLBACK|COMMIT|CREATE|DROP).*?)\n\s*((?:{stop})|\n|$)', re.S) + sqlre2 = re.compile(r'{opensql}(.*?)\n?((?:PRAGMA|BEGIN|SELECT|INSERT|DELETE|ROLLBACK|COMMIT|UPDATE|CREATE|DROP|PRAGMA|DESCRIBE).*?)\n\s*((?:{stop})|\n|$)', re.S) + text = sqlre2.sub(r"\1<%call expr='formatting.poppedcode()' >\2\n\n", text) - tag = et.Element("MAKO:formatting.code") + tag = et.Element("MAKO:formatting.code", extension='extension', paged='paged', toc='toc') if code: tag.attrib["syntaxtype"] = repr(code) if title: @@ -170,13 +172,14 @@ def replace_pre_with_mako(tree): parents = get_parent_map(tree) for precode in tree.findall('.//pre/code'): - reg = re.compile(r'\{(python|code)(?: title="(.*?)"){0,1}\}(.*)', re.S) + reg = re.compile(r'\{(python|code|diagram)(?: title="(.*?)"){0,1}\}(.*)', re.S) m = reg.match(precode[0].text.lstrip()) if m: code = m.group(1) title = m.group(2) text = m.group(3) - text = re.sub(r'{(python|code).*?}(\n\s*)?', '', text) + text = re.sub(r'{(python|code|diagram).*?}(\n\s*)?', '', text) + text = re.sub(r'\\\n', r'${r"\\\\" + "\\n\\n"}', text) splice_code_tag(parents[precode], text, code=code, title=title) elif precode.text.lstrip().startswith('>>> '): splice_code_tag(parents[precode], precode.text) @@ -225,6 +228,8 @@ def parse_markdown_files(toc, files): if not os.access(infile, os.F_OK): continue html = markdown.markdown(file(infile).read()) + #foo = file('foo', 'w') + #foo.write(html) tree = et.fromstring("" + html + "") (title, toc_element) = create_toc(inname, tree, toc) safety_code(tree) diff --git a/doc/build/templates/formatting.html b/doc/build/templates/formatting.html index 4068e3adfa..d9a7aa923c 100644 --- a/doc/build/templates/formatting.html +++ b/doc/build/templates/formatting.html @@ -18,7 +18,7 @@ <% content = capture(caller.body) re2 = re.compile(r"'''PYESC(.+?)PYESC'''", re.S) - content = re2.sub(lambda m: m.group(1), content) + content = re2.sub(lambda m: filters.url_unescape(m.group(1)), content) item = toc.get_by_path(path) subsection = item.depth > 1 @@ -36,9 +36,9 @@ % if len(item.children) == 0: % if paged: - back to section top + back to section top % else: - back to section top + back to section top % endif % endif @@ -55,7 +55,7 @@ ${ caller.body() } -<%def name="code(title=None, syntaxtype='mako', html_escape=False, use_sliders=False)"> +<%def name="code(toc, paged, extension, title=None, syntaxtype='mako', html_escape=True, use_sliders=False)"> <% def fix_indent(f): f =string.expandtabs(f, 4) @@ -73,18 +73,27 @@ if whitespace is not None or re.search(r"\w", line) is not None: g += (line + "\n") - - - return g.rstrip() - + else: + g += 
"\n" + + return g[:-1] #.rstrip() + p = re.compile(r'
<pre>(.*?)</pre>', re.S)
-        def hlight(match):
-            return "<pre>" + highlight.highlight(fix_indent(match.group(1)), html_escape = html_escape, syntaxtype = syntaxtype) + "</pre>"
-        try:
-            content = p.sub(hlight, "<pre>" + capture(caller.body) + "</pre>")
-        except:
-            raise "the content is " + str(capture(caller.body))
+        def hlight(match):
+            try:
+                return "<pre>" + highlight.highlight(fix_indent(match.group(1)), html_escape = html_escape, syntaxtype = syntaxtype) + "</pre>"
+            except:
+                print "TEXT IS", fix_indent(match.group(1))
+
+        def link(match):
+            return capture(nav.toclink, toc, match.group(2), extension, paged, description=match.group(1))
+
+        content = re.sub(r'\[(.+?)\]\(rel:(.+?)\)', link, capture(caller.body))
+        if syntaxtype != 'diagram':
+            content = p.sub(hlight, "<pre>" + content + "</pre>")
+        else:
+            content = "<pre>" + content + "</pre>"
    %>
@@ -119,13 +128,13 @@
 javascript:togglePopbox('${name}', '${show}', '${hide}')
     <% href = capture(popboxlink) %>
-    '''PYESC${nav.link(href=href, text=link, class_="codepoplink")}PYESC'''
+    '''PYESC${capture(nav.link, href=href, text=link, class_="codepoplink") | u}PYESC'''

 <%def name="codepopper()" filter="trim">
    <%
        c = capture(caller.body)
-        c = re.sub(r'\n', '<br/>\n', c.strip())
+        c = re.sub(r'\n', '<br/>\n', filters.html_escape(c.strip()))
    %>
    <%call expr="popbox(class_='codepop')">${c}
 
@@ -133,7 +142,7 @@ javascript:togglePopbox('${name}', '${show}', '${hide}')
 <%def name="poppedcode()" filter="trim">
     <%
 		c = capture(caller.body)
-		c = re.sub(r'\n', '<br/>\n', c.strip())
+		c = re.sub(r'\n', '<br/>\n', filters.html_escape(c.strip()))
    %>
${c}
 
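The new link() helper added to the code() def above rewrites [description](rel:path) references found inside documentation code blocks into TOC links before syntax highlighting runs. A minimal sketch of that substitution outside the Mako template, assuming a simplified, hypothetical toclink() in place of the real nav.toclink def:

    import re

    def toclink(path, description):
        # hypothetical stand-in for nav.toclink: resolve a TOC path to an anchor
        return '<a href="%s.html">%s</a>' % (path, description)

    def link(match):
        return toclink(match.group(2), match.group(1))

    text = 'see [the ORM tutorial](rel:ormtutorial) for details'
    print re.sub(r'\[(.+?)\]\(rel:(.+?)\)', link, text)
    # prints: see <a href="ormtutorial.html">the ORM tutorial</a> for details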
diff --git a/doc/build/templates/nav.html b/doc/build/templates/nav.html
index 38147d89de..55fc5e8dca 100644
--- a/doc/build/templates/nav.html
+++ b/doc/build/templates/nav.html
@@ -22,6 +22,9 @@
     % if item:
         ${ description }
     % else:
+        <%
+        #raise Exception("Can't find TOC link for '%s'" % path)
+        %>
         ${ description }
     % endif
 
@@ -42,7 +45,7 @@
 
 <%def name="pagenav(item, paged, extension)">
     
@@ -70,4 +73,4 @@
             Next: ${itemlink(item=item.next, paged=paged, anchor=not paged, extension=extension)}
         % endif
     
- \ No newline at end of file + diff --git a/doc/build/templates/pydoc.html b/doc/build/templates/pydoc.html index 34bb5e7bc3..a4f4caf141 100644 --- a/doc/build/templates/pydoc.html +++ b/doc/build/templates/pydoc.html @@ -37,12 +37,14 @@ def formatdocstring(content): <%def name="inline_links(toc, extension, paged)"><% def link(match): (module, desc) = match.group(1,2) - if desc.endswith('()'): + if not desc: + path = "docstrings_" + module + elif desc.endswith('()'): path = "docstrings_" + module + "_modfunc_" + desc[:-2] else: path = "docstrings_" + module + "_" + desc - return capture(nav.toclink, toc=toc, path=path, description=desc, extension=extension, paged=paged) - return lambda content: re.sub(r'\[(.+?)#(.+?)?\]', link, content) + return capture(nav.toclink, toc=toc, path=path, description=desc or None, extension=extension, paged=paged) + return lambda content: re.sub('\[(.+?)#(.*?)\]', link, content) %> <%namespace name="formatting" file="formatting.html"/> diff --git a/doc/build/templates/toc.html b/doc/build/templates/toc.html index f4ea353ec1..0d8f4ca4ca 100644 --- a/doc/build/templates/toc.html +++ b/doc/build/templates/toc.html @@ -6,15 +6,15 @@

Table of Contents

   - (view full table) + (view full table)

${printtoc(root=toc,paged=paged, extension=extension, current=None,children=False,anchor_toplevel=False)}

Table of Contents: Full

   - (view brief table) - + (view brief table) + ${printtoc(root=toc,paged=paged, extension=extension, current=None,children=True,anchor_toplevel=False)} @@ -22,19 +22,23 @@ <%def name="printtoc(root, paged, extension, current=None, children=True, anchor_toplevel=False)"> + % if root.children:
    - % for item in root.children: + % for item in root.children: <% anchor = anchor_toplevel if paged and item.filename != root.filename: anchor = False %> -
${item.description}
- - % if children: - ${printtoc(item, current=current, children=True,anchor_toplevel=True, paged=paged, extension=extension)} - % endif - % endfor +
${item.description}
+ + % if children and item.children: +
+ ${printtoc(item, current=current, children=True,anchor_toplevel=True, paged=paged, extension=extension)} +
+ % endif + % endfor
+ % endif diff --git a/doc/build/testdocs.py b/doc/build/testdocs.py index e93c4c5799..998320fb0d 100644 --- a/doc/build/testdocs.py +++ b/doc/build/testdocs.py @@ -1,69 +1,71 @@ -import sys -sys.path = ['../../lib', './lib/'] + sys.path - -import os -import re -import doctest -import sqlalchemy.util as util -import sqlalchemy.logging as salog -import logging - -salog.default_enabled=True -rootlogger = logging.getLogger('sqlalchemy') -rootlogger.setLevel(logging.NOTSET) -class MyStream(object): - def write(self, string): - sys.stdout.write(string) - sys.stdout.flush() - def flush(self): - pass -handler = logging.StreamHandler(MyStream()) -handler.setFormatter(logging.Formatter('%(message)s')) -rootlogger.addHandler(handler) - - -def teststring(s, name, globs=None, verbose=None, report=True, - optionflags=0, extraglobs=None, raise_on_error=False, - parser=doctest.DocTestParser()): - - from doctest import DebugRunner, DocTestRunner, master - - # Assemble the globals. - if globs is None: - globs = {} - else: - globs = globs.copy() - if extraglobs is not None: - globs.update(extraglobs) - - if raise_on_error: - runner = DebugRunner(verbose=verbose, optionflags=optionflags) - else: - runner = DocTestRunner(verbose=verbose, optionflags=optionflags) - - test = parser.get_doctest(s, globs, name, name, 0) - runner.run(test) - - if report: - runner.summarize() - - if master is None: - master = runner - else: - master.merge(runner) - - return runner.failures, runner.tries - -def replace_file(s, newfile): - engine = r"'(sqlite|postgres|mysql):///.*'" - engine = re.compile(engine, re.MULTILINE) - s, n = re.subn(engine, "'sqlite:///" + newfile + "'", s) - if not n: - raise ValueError("Couldn't find suitable create_engine call to replace '%s' in it" % oldfile) - return s - -filename = 'content/tutorial.txt' -s = open(filename).read() -s = replace_file(s, ':memory:') -teststring(s, filename) - +import sys +sys.path = ['../../lib', './lib/'] + sys.path + +import os +import re +import doctest +import sqlalchemy.util as util +import sqlalchemy.logging as salog +import logging + +salog.default_enabled=True +rootlogger = logging.getLogger('sqlalchemy') +rootlogger.setLevel(logging.NOTSET) +class MyStream(object): + def write(self, string): + sys.stdout.write(string) + sys.stdout.flush() + def flush(self): + pass +handler = logging.StreamHandler(MyStream()) +handler.setFormatter(logging.Formatter('%(message)s')) +rootlogger.addHandler(handler) + + +def teststring(s, name, globs=None, verbose=None, report=True, + optionflags=0, extraglobs=None, raise_on_error=False, + parser=doctest.DocTestParser()): + + from doctest import DebugRunner, DocTestRunner, master + + # Assemble the globals. 
+ if globs is None: + globs = {} + else: + globs = globs.copy() + if extraglobs is not None: + globs.update(extraglobs) + + if raise_on_error: + runner = DebugRunner(verbose=verbose, optionflags=optionflags) + else: + runner = DocTestRunner(verbose=verbose, optionflags=optionflags) + + test = parser.get_doctest(s, globs, name, name, 0) + runner.run(test) + + if report: + runner.summarize() + + if master is None: + master = runner + else: + master.merge(runner) + + return runner.failures, runner.tries + +def replace_file(s, newfile): + engine = r"'(sqlite|postgres|mysql):///.*'" + engine = re.compile(engine, re.MULTILINE) + s, n = re.subn(engine, "'sqlite:///" + newfile + "'", s) + if not n: + raise ValueError("Couldn't find suitable create_engine call to replace '%s' in it" % oldfile) + return s + +for filename in ('ormtutorial', 'sqlexpression'): + filename = 'content/%s.txt' % filename + s = open(filename).read() + #s = replace_file(s, ':memory:') + s = re.sub(r'{(?:stop|sql|opensql)}', '', s) + teststring(s, filename) + diff --git a/doc/docs.css b/doc/docs.css index e78533cb83..72cf4d6a82 100644 --- a/doc/docs.css +++ b/doc/docs.css @@ -10,6 +10,11 @@ margin:10px 0px 10px 0px; } +pre { + margin:0px; + padding:0px; +} + .prevnext { padding: 5px 0px 0px 0px; } @@ -72,7 +77,7 @@ h3 { } .sectionL2 { - margin:0px 0px 0px 20px; + margin:0px 0px 0px 0px; line-height: 1.5em; } @@ -141,7 +146,7 @@ h3 { font-size:12px; background-color: #f0f0f0; border: solid 1px #ccc; - padding:2px 2px 2px 10px; + padding:10px; margin: 5px 5px 5px 5px; overflow:auto; } @@ -151,7 +156,7 @@ h3 { font-size:12px; background-color: #f0f0f0; border: solid 1px #ccc; - padding:2px 2px 2px 10px; + padding:10px; /*2px 2px 2px 10px;*/ margin: 5px 5px 5px 5px; line-height:1.2em; } @@ -193,3 +198,11 @@ h3 { background-color: #900; } +@media print { + #nav { display: none; } + #pagecontrol { display: none; } + .topnav .prevnext { display: none; } + .bottomnav { display: none; } + .totoc { display: none; } + .topnav ul li a { text-decoration: none; color: #000; } +} \ No newline at end of file diff --git a/doc/syntaxhighlight.css b/doc/syntaxhighlight.css index 8529ebb5fc..d0ecae2d3b 100644 --- a/doc/syntaxhighlight.css +++ b/doc/syntaxhighlight.css @@ -34,7 +34,7 @@ } .python_operator { - color: #EF0005; + color: #BF0005; } .python_enclosure { diff --git a/examples/adjacencytree/basic_tree.py b/examples/adjacencytree/basic_tree.py index 53bdc82983..65e7d0da8a 100644 --- a/examples/adjacencytree/basic_tree.py +++ b/examples/adjacencytree/basic_tree.py @@ -1,18 +1,18 @@ -"""a basic Adjacency List model tree.""" +"""A basic Adjacency List model tree.""" -from sqlalchemy import * -from sqlalchemy.orm import * -from sqlalchemy.util import OrderedDict +from sqlalchemy import MetaData, Table, Column, Sequence, ForeignKey +from sqlalchemy import Integer, String +from sqlalchemy.orm import create_session, mapper, relation, backref from sqlalchemy.orm.collections import attribute_mapped_collection metadata = MetaData('sqlite:///') metadata.bind.echo = True trees = Table('treenodes', metadata, - Column('node_id', Integer, Sequence('treenode_id_seq',optional=False), primary_key=True), - Column('parent_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True), - Column('node_name', String(50), nullable=False), - ) + Column('id', Integer, Sequence('treenode_id_seq', optional=True), + primary_key=True), + Column('parent_id', Integer, ForeignKey('treenodes.id'), nullable=True), + Column('name', String(50), nullable=False)) class 
TreeNode(object): @@ -32,19 +32,20 @@ class TreeNode(object): def __str__(self): return self._getstring(0, False) def _getstring(self, level, expand = False): - s = (' ' * level) + "%s (%s,%s, %d)" % (self.name, self.id,self.parent_id,id(self)) + '\n' + s = (' ' * level) + "%s (%s,%s, %d)" % ( + self.name, self.id,self.parent_id,id(self)) + '\n' if expand: - s += ''.join([n._getstring(level+1, True) for n in self.children.values()]) + s += ''.join([n._getstring(level+1, True) + for n in self.children.values()]) return s def print_nodes(self): return self._getstring(0, True) - -mapper(TreeNode, trees, properties=dict( - id=trees.c.node_id, - name=trees.c.node_name, - parent_id=trees.c.parent_node_id, - children=relation(TreeNode, cascade="all", backref=backref("parent", remote_side=[trees.c.node_id]), collection_class=attribute_mapped_collection('name')), -)) + +mapper(TreeNode, trees, properties={ + 'children': relation(TreeNode, cascade="all", + backref=backref("parent", remote_side=[trees.c.id]), + collection_class=attribute_mapped_collection('name'), + lazy=False, join_depth=3)}) print "\n\n\n----------------------------" print "Creating Tree Table:" @@ -113,7 +114,7 @@ print "tree new where node_id=%d:" % nodeid print "----------------------------" session.clear() -t = session.query(TreeNode).select(TreeNode.c.id==nodeid)[0] +t = session.query(TreeNode).filter(TreeNode.c.id==nodeid)[0] print "\n\n\n----------------------------" print "Full Tree:" diff --git a/examples/adjacencytree/byroot_tree.py b/examples/adjacencytree/byroot_tree.py index a61bde8757..e57b11beec 100644 --- a/examples/adjacencytree/byroot_tree.py +++ b/examples/adjacencytree/byroot_tree.py @@ -1,40 +1,48 @@ -"""a more advanced example of basic_tree.py. treenodes can now reference their "root" node, and -introduces a new selection method which selects an entire tree of nodes at once, taking -advantage of a custom MapperExtension to assemble incoming nodes into their correct structure.""" - -from sqlalchemy import * -from sqlalchemy.orm import * +"""A more advanced example of basic_tree.py. + +Treenodes can now reference their "root" node, and introduces a new +selection method which selects an entire tree of nodes at once, taking +advantage of a custom MapperExtension to assemble incoming nodes into their +correct structure. +""" + +from sqlalchemy import MetaData, Table, Column, Sequence, ForeignKey +from sqlalchemy import Integer, String +from sqlalchemy.orm import create_session, mapper, relation, backref +from sqlalchemy.orm import MapperExtension from sqlalchemy.orm.collections import attribute_mapped_collection -engine = create_engine('sqlite:///:memory:', echo=True) -metadata = MetaData(engine) +metadata = MetaData('sqlite:///') +metadata.bind.echo = True -"""create the treenodes table. This is ia basic adjacency list model table. -One additional column, "root_node_id", references a "root node" row and is used -in the 'byroot_tree' example.""" +# Create the `treenodes` table, a basic adjacency list model table. +# One additional column, "root_id", references a "root node" row and is used +# in the 'byroot_tree' example. 
trees = Table('treenodes', metadata, - Column('node_id', Integer, Sequence('treenode_id_seq',optional=False), primary_key=True), - Column('parent_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True), - Column('root_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True), - Column('node_name', String(50), nullable=False), - Column('data_ident', Integer, ForeignKey('treedata.data_id')) - ) + Column('id', Integer, Sequence('treenode_id_seq', optional=True), + primary_key=True), + Column('parent_id', Integer, ForeignKey('treenodes.id'), nullable=True), + Column('root_id', Integer, ForeignKey('treenodes.id'), nullable=True), + Column('name', String(50), nullable=False), + Column('data_id', Integer, ForeignKey('treedata.data_id'))) treedata = Table( - "treedata", metadata, + "treedata", metadata, Column('data_id', Integer, primary_key=True), - Column('value', String(100), nullable=False) -) + Column('value', String(100), nullable=False)) class TreeNode(object): - """a hierarchical Tree class, which adds the concept of a "root node". The root is - the topmost node in a tree, or in other words a node whose parent ID is NULL. - All child nodes that are decendents of a particular root, as well as a root node itself, - reference this root node. """ - + """A hierarchical Tree class, + + Adds the concept of a "root node". The root is the topmost node in a + tree, or in other words a node whose parent ID is NULL. All child nodes + that are decendents of a particular root, as well as a root node itself, + reference this root node. + """ + def __init__(self, name): self.name = name self.root = self @@ -45,10 +53,10 @@ class TreeNode(object): c._set_root(root) def append(self, node): - if isinstance(node, str): + if isinstance(node, basestring): node = TreeNode(node) node._set_root(self.root) - self.children.append(node) + self.children.set(node) def __repr__(self): return self._getstring(0, False) @@ -57,32 +65,43 @@ class TreeNode(object): return self._getstring(0, False) def _getstring(self, level, expand = False): - s = (' ' * level) + "%s (%s,%s,%s, %d): %s" % (self.name, self.id,self.parent_id,self.root_id, id(self), repr(self.data)) + '\n' + s = "%s%s (%s,%s,%s, %d): %s\n" % ( + (' ' * level), self.name, self.id,self.parent_id, + self.root_id, id(self), repr(self.data)) if expand: - s += ''.join([n._getstring(level+1, True) for n in self.children.values()]) + s += ''.join([n._getstring(level+1, True) + for n in self.children.values()]) return s def print_nodes(self): return self._getstring(0, True) - + class TreeLoader(MapperExtension): def after_insert(self, mapper, connection, instance): - """runs after the insert of a new TreeNode row. The primary key of the row is not determined - until the insert is complete, since most DB's use autoincrementing columns. If this node is - the root node, we will take the new primary key and update it as the value of the node's - "root ID" as well, since its root node is itself.""" + """ + Runs after the insert of a new TreeNode row. The primary key of the + row is not determined until the insert is complete, since most DB's + use autoincrementing columns. If this node is the root node, we + will take the new primary key and update it as the value of the + node's "root ID" as well, since its root node is itself. 
+ """ if instance.root is instance: - connection.execute(mapper.mapped_table.update(TreeNode.c.id==instance.id, values=dict(root_node_id=instance.id))) + connection.execute(mapper.mapped_table.update( + TreeNode.c.id==instance.id, values=dict(root_id=instance.id))) instance.root_id = instance.id def append_result(self, mapper, selectcontext, row, instance, result, **flags): - """runs as results from a SELECT statement are processed, and newly created or already-existing - instances that correspond to each row are appended to result lists. This method will only - append root nodes to the result list, and will attach child nodes to their appropriate parent - node as they arrive from the select results. This allows a SELECT statement which returns - both root and child nodes in one query to return a list of "roots".""" + """ + Runs as results from a SELECT statement are processed, and newly + created or already-existing instances that correspond to each row + are appended to result lists. This method will only append root + nodes to the result list, and will attach child nodes to their + appropriate parent node as they arrive from the select results. + This allows a SELECT statement which returns both root and child + nodes in one query to return a list of "roots". + """ isnew = flags.get('isnew', False) @@ -90,10 +109,11 @@ class TreeLoader(MapperExtension): result.append(instance) else: if isnew or selectcontext.populate_existing: - parentnode = selectcontext.identity_map[mapper.identity_key(instance.parent_id)] - parentnode.children.append(instance) + key = mapper.identity_key_from_primary_key(instance.parent_id) + parentnode = selectcontext.session.identity_map[key] + parentnode.children.set(instance) return False - + class TreeData(object): def __init__(self, value=None): self.id = None @@ -108,24 +128,30 @@ print "----------------------------" metadata.create_all() -# the mapper is created with properties that specify "lazy=None" - this is because we are going -# to handle our own "eager load" of nodes based on root id mapper(TreeNode, trees, properties=dict( - id=trees.c.node_id, - name=trees.c.node_name, - parent_id=trees.c.parent_node_id, - root_id=trees.c.root_node_id, - root=relation(TreeNode, primaryjoin=trees.c.root_node_id==trees.c.node_id, remote_side=trees.c.node_id, lazy=None), - children=relation(TreeNode, - primaryjoin=trees.c.parent_node_id==trees.c.node_id, - lazy=None, - cascade="all", + # 'root' attribute. has a load-only backref '_descendants' that loads + # all nodes with the same root ID eagerly, which are intercepted by the + # TreeLoader extension and populated into the "children" collection. + root=relation(TreeNode, primaryjoin=trees.c.root_id==trees.c.id, + remote_side=trees.c.id, lazy=None, + backref=backref('_descendants', lazy=False, join_depth=1, + primaryjoin=trees.c.root_id==trees.c.id,viewonly=True)), + + # 'children' attribute. collection of immediate child nodes. this is a + # non-loading relation which is populated by the TreeLoader extension. + children=relation(TreeNode, primaryjoin=trees.c.parent_id==trees.c.id, + lazy=None, cascade="all", collection_class=attribute_mapped_collection('name'), - backref=backref('parent', primaryjoin=trees.c.parent_node_id==trees.c.node_id, remote_side=trees.c.node_id) + backref=backref('parent', + primaryjoin=trees.c.parent_id==trees.c.id, + remote_side=trees.c.id) ), + + # 'data' attribute. A collection of secondary objects which also loads + # eagerly. 
data=relation(TreeData, cascade="all, delete-orphan", lazy=False) - -), extension = TreeLoader()) + +), extension=TreeLoader()) mapper(TreeData, treedata, properties={'id':treedata.c.data_id}) @@ -194,9 +220,12 @@ print "----------------------------" session.clear() -# load some nodes. we do this based on "root id" which will load an entire sub-tree in one pass. -# the MapperExtension will assemble the incoming nodes into a tree structure. -t = session.query(TreeNode).select(TreeNode.c.root_id==nodeid, order_by=[TreeNode.c.id])[0] +# load some nodes. we do this based on "root id" which will load an entire +# sub-tree in one pass. the MapperExtension will assemble the incoming +# nodes into a tree structure. +t = (session.query(TreeNode). + filter(TreeNode.c.root_id==nodeid). + order_by([TreeNode.c.id]))[0] print "\n\n\n----------------------------" print "Full Tree:" diff --git a/examples/association/basic_association.py b/examples/association/basic_association.py index fabfdfa783..8078a2bb9c 100644 --- a/examples/association/basic_association.py +++ b/examples/association/basic_association.py @@ -1,22 +1,26 @@ -"""basic example of using the association object pattern, which is -a richer form of a many-to-many relationship.""" +"""A basic example of using the association object pattern. +The association object pattern is a richer form of a many-to-many +relationship. -# the model will be an ecommerce example. We will have an -# Order, which represents a set of Items purchased by a user. -# each Item has a price. however, the Order must store its own price for -# each Item, representing the price paid by the user for that particular order, which -# is independent of the price on each Item (since those can change). +The model will be an ecommerce example. We will have an Order, which +represents a set of Items purchased by a user. Each Item has a price. +However, the Order must store its own price for each Item, representing +the price paid by the user for that particular order, which is independent +of the price on each Item (since those can change). +""" -from sqlalchemy import * -from sqlalchemy.ext.selectresults import SelectResults +import logging from datetime import datetime -import logging -logging.basicConfig(format='%(message)s') -logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) +from sqlalchemy import * +from sqlalchemy.orm import * -engine = create_engine('sqlite://') +# Uncomment these to watch database activity. 
+#logging.basicConfig(format='%(message)s') +#logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) + +engine = create_engine('sqlite:///') metadata = MetaData(engine) orders = Table('orders', metadata, @@ -28,13 +32,15 @@ orders = Table('orders', metadata, items = Table('items', metadata, Column('item_id', Integer, primary_key=True), Column('description', String(30), nullable=False), - Column('price', Float, nullable=False) + Column('price', Numeric(8, 2), nullable=False) ) orderitems = Table('orderitems', metadata, - Column('order_id', Integer, ForeignKey('orders.order_id'), primary_key=True), - Column('item_id', Integer, ForeignKey('items.item_id'), primary_key=True), - Column('price', Float, nullable=False) + Column('order_id', Integer, ForeignKey('orders.order_id'), + primary_key=True), + Column('item_id', Integer, ForeignKey('items.item_id'), + primary_key=True), + Column('price', Numeric(8, 2), nullable=False) ) metadata.create_all() @@ -46,6 +52,8 @@ class Item(object): def __init__(self, description, price): self.description = description self.price = price + def __repr__(self): + return 'Item(%s, %s)' % (repr(self.description), repr(self.price)) class OrderItem(object): def __init__(self, item, price=None): @@ -53,11 +61,12 @@ class OrderItem(object): self.price = price or item.price mapper(Order, orders, properties={ - 'items':relation(OrderItem, cascade="all, delete-orphan", lazy=False) + 'order_items': relation(OrderItem, cascade="all, delete-orphan", + backref='order') }) mapper(Item, items) mapper(OrderItem, orderitems, properties={ - 'item':relation(Item, lazy=False) + 'item': relation(Item, lazy=False) }) session = create_session() @@ -71,34 +80,28 @@ session.flush() # function to return items from the DB def item(name): - return session.query(Item).get_by(description=name) + return session.query(Item).filter_by(description=name).one() # create an order order = Order('john smith') # add three OrderItem associations to the Order and save -order.items.append(OrderItem(item('SA Mug'))) -order.items.append(OrderItem(item('MySQL Crowbar'), 10.99)) -order.items.append(OrderItem(item('SA Hat'))) +order.order_items.append(OrderItem(item('SA Mug'))) +order.order_items.append(OrderItem(item('MySQL Crowbar'), 10.99)) +order.order_items.append(OrderItem(item('SA Hat'))) session.save(order) session.flush() session.clear() # query the order, print items -order = session.query(Order).get_by(customer_name='john smith') -print [(item.item.description, item.price) for item in order.items] +order = session.query(Order).filter_by(customer_name='john smith').one() +print [(order_item.item.description, order_item.price) + for order_item in order.order_items] # print customers who bought 'MySQL Crowbar' on sale -result = SelectResults(session.query(Order)).join_to('item').select(and_(items.c.description=='MySQL Crowbar', items.c.price>orderitems.c.price)) -print [order.customer_name for order in result] - - - - - - - - - +q = session.query(Order).join(['order_items', 'item']) +q = q.filter(and_(Item.description == 'MySQL Crowbar', + Item.price > OrderItem.price)) +print [order.customer_name for order in q] diff --git a/examples/association/proxied_association.py b/examples/association/proxied_association.py index 2dd60158b9..f7dd45c4ae 100644 --- a/examples/association/proxied_association.py +++ b/examples/association/proxied_association.py @@ -2,6 +2,7 @@ the usage of the associationproxy extension.""" from sqlalchemy import * +from sqlalchemy.orm import * from 
sqlalchemy.ext.selectresults import SelectResults from sqlalchemy.ext.associationproxy import AssociationProxy from datetime import datetime @@ -66,7 +67,7 @@ session.flush() # function to return items def item(name): - return session.query(Item).get_by(description=name) + return session.query(Item).filter_by(description=name).one() # create an order order = Order('john smith') @@ -88,7 +89,7 @@ session.flush() session.clear() # query the order, print items -order = session.query(Order).get_by(customer_name='john smith') +order = session.query(Order).filter_by(customer_name='john smith').one() # print items based on the OrderItem collection directly print [(item.item.description, item.price) for item in order.itemassociations] @@ -97,11 +98,11 @@ print [(item.item.description, item.price) for item in order.itemassociations] print [(item.description, item.price) for item in order.items] # print customers who bought 'MySQL Crowbar' on sale -result = session.query(Order).join('item').filter(and_(items.c.description=='MySQL Crowbar', items.c.price>orderitems.c.price)) +result = session.query(Order).join(['itemassociations', 'item']).filter(and_(Item.description=='MySQL Crowbar', Item.price>OrderItem.price)) print [order.customer_name for order in result] # print customers who got the special T-shirt discount -result = session.query(Order).join('item').filter(and_(items.c.description=='SA T-Shirt', items.c.price>orderitems.c.price)) +result = session.query(Order).join(['itemassociations', 'item']).filter(and_(Item.description=='SA T-Shirt', Item.price>OrderItem.price)) print [order.customer_name for order in result] diff --git a/examples/collections/large_collection.py b/examples/collections/large_collection.py index 3c53db121c..203aa6d230 100644 --- a/examples/collections/large_collection.py +++ b/examples/collections/large_collection.py @@ -1,6 +1,11 @@ -"""illlustrates techniques for dealing with very large collections""" +"""illlustrates techniques for dealing with very large collections. + +Also see the docs regarding the new "dynamic" relation option, which +presents a more refined version of some of these patterns. +""" from sqlalchemy import * +from sqlalchemy.orm import * meta = MetaData('sqlite://') meta.bind.echo = True @@ -60,7 +65,7 @@ sess.clear() # reload. load the org and some child members print "-------------------------\nload subset of members" org = sess.query(Organization).get(org.org_id) -members = org.member_query.filter_by(member_table.c.name.like('%member t%')).list() +members = org.member_query.filter(member_table.c.name.like('%member t%')).all() print members sess.clear() diff --git a/examples/dynamic_dict/dynamic_dict.py b/examples/dynamic_dict/dynamic_dict.py new file mode 100644 index 0000000000..682def78c3 --- /dev/null +++ b/examples/dynamic_dict/dynamic_dict.py @@ -0,0 +1,83 @@ +"""Illustrates how to place a dictionary-like facade on top of a dynamic_loader, so +that dictionary operations (assuming simple string keys) can operate upon a large +collection without loading the full collection at once. + +This is something that may eventually be added as a feature to dynamic_loader() itself. + +Similar approaches could be taken towards sets and dictionaries with non-string keys +although the hash policy of the members would need to be distilled into a filter() criterion. 
+ +""" + +class MyProxyDict(object): + def __init__(self, parent, collection_name, keyname): + self.parent = parent + self.collection_name = collection_name + self.keyname = keyname + + def collection(self): + return getattr(self.parent, self.collection_name) + collection = property(collection) + + def keys(self): + # this can be improved to not query all columns + return [getattr(x, self.keyname) for x in self.collection.all()] + + def __getitem__(self, key): + x = self.collection.filter_by(**{self.keyname:key}).first() + if x: + return x + else: + raise KeyError(key) + + def __setitem__(self, key, value): + try: + existing = self[key] + self.collection.remove(existing) + except KeyError: + pass + self.collection.append(value) + +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import * +from sqlalchemy.orm import * + +Base = declarative_base(engine=create_engine('sqlite://')) + +class MyParent(Base): + __tablename__ = 'parent' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + _collection = dynamic_loader("MyChild", cascade="all, delete-orphan") + + def child_map(self): + return MyProxyDict(self, '_collection', 'key') + child_map = property(child_map) + +class MyChild(Base): + __tablename__ = 'child' + id = Column(Integer, primary_key=True) + key = Column(String(50)) + parent_id = Column(Integer, ForeignKey('parent.id')) + + +Base.metadata.create_all() + +sess = create_session(autoflush=True, transactional=True) + +p1 = MyParent(name='p1') +sess.save(p1) + +p1.child_map['k1'] = k1 = MyChild(key='k1') +p1.child_map['k2'] = k2 = MyChild(key='k2') + + +assert p1.child_map.keys() == ['k1', 'k2'] + +assert p1.child_map['k1'] is k1 + +p1.child_map['k2'] = k2b = MyChild(key='k2') +assert p1.child_map['k2'] is k2b + +assert sess.query(MyChild).all() == [k1, k2b] + diff --git a/examples/elementtree/adjacency_list.py b/examples/elementtree/adjacency_list.py index 204662f561..706cc88a07 100644 --- a/examples/elementtree/adjacency_list.py +++ b/examples/elementtree/adjacency_list.py @@ -26,7 +26,7 @@ from elementtree import ElementTree from elementtree.ElementTree import Element, SubElement meta = MetaData() -meta.engine = 'sqlite://' +meta.bind = 'sqlite://' ################################# PART II - Table Metadata ########################################### @@ -129,11 +129,11 @@ class ElementTreeMarshal(object): def __set__(self, document, element): def traverse(node): n = _Node() - n.tag = node.tag - n.text = node.text - n.tail = node.tail + n.tag = unicode(node.tag) + n.text = unicode(node.text) + n.tail = unicode(node.tail) n.children = [traverse(n2) for n2 in node] - n.attributes = [_Attribute(k, v) for k, v in node.attrib.iteritems()] + n.attributes = [_Attribute(unicode(k), unicode(v)) for k, v in node.attrib.iteritems()] return n document._root = traverse(element.getroot()) @@ -174,12 +174,10 @@ print document ############################################ PART VI - Searching for Paths ####################################### # manually search for a document which contains "/somefile/header/field1:hi" -print "\nManual search for /somefile/header/field1=='hi':", line -n1 = elements.alias('n1') -n2 = elements.alias('n2') -n3 = elements.alias('n3') -j = documents.join(n1).join(n2, n1.c.element_id==n2.c.parent_id).join(n3, n2.c.element_id==n3.c.parent_id) -d = session.query(Document).select_from(j).filter(n1.c.tag=='somefile').filter(n2.c.tag=='header').filter(and_(n3.c.tag=='field1', n3.c.text=='hi')).one() +d = session.query(Document).join('_root', 
aliased=True).filter(_Node.tag==u'somefile').\ + join('children', aliased=True, from_joinpoint=True).filter(_Node.tag==u'header').\ + join('children', aliased=True, from_joinpoint=True).filter(and_(_Node.tag==u'field1', _Node.text==u'hi')).\ + one() print d # generalize the above approach into an extremely impoverished xpath function: @@ -187,28 +185,23 @@ def find_document(path, compareto): j = documents prev_elements = None query = session.query(Document) + attribute = '_root' for i, match in enumerate(re.finditer(r'/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?', path)): (token, attrname, attrvalue) = match.group(1, 2, 3) - a = elements.alias("n%d" % i) - query = query.filter(a.c.tag==token) + query = query.join(attribute, aliased=True, from_joinpoint=True).filter(_Node.tag==token) + attribute = 'children' if attrname: - attr_alias = attributes.alias('a%d' % i) if attrvalue: - query = query.filter(and_(a.c.element_id==attr_alias.c.element_id, attr_alias.c.name==attrname, attr_alias.c.value==attrvalue)) + query = query.join('attributes', aliased=True, from_joinpoint=True).filter(and_(_Attribute.name==attrname, _Attribute.value==attrvalue)) else: - query = query.filter(and_(a.c.element_id==attr_alias.c.element_id, attr_alias.c.name==attrname)) - if prev_elements is not None: - j = j.join(a, prev_elements.c.element_id==a.c.parent_id) - else: - j = j.join(a) - prev_elements = a - return query.options(lazyload('_root')).select_from(j).filter(prev_elements.c.text==compareto).all() + query = query.join('attributes', aliased=True, from_joinpoint=True).filter(_Attribute.name==attrname) + return query.options(lazyload('_root')).filter(_Node.text==compareto).all() for path, compareto in ( - ('/somefile/header/field1', 'hi'), - ('/somefile/field1', 'hi'), - ('/somefile/header/field2', 'there'), - ('/somefile/header/field2[@attr=foo]', 'there') + (u'/somefile/header/field1', u'hi'), + (u'/somefile/field1', u'hi'), + (u'/somefile/header/field2', u'there'), + (u'/somefile/header/field2[@attr=foo]', u'there') ): print "\nDocuments containing '%s=%s':" % (path, compareto), line print [d.filename for d in find_document(path, compareto)] diff --git a/examples/elementtree/optimized_al.py b/examples/elementtree/optimized_al.py index 17b6489de0..316f17c679 100644 --- a/examples/elementtree/optimized_al.py +++ b/examples/elementtree/optimized_al.py @@ -18,14 +18,14 @@ logging.basicConfig() #logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) # uncomment to show SQL statements and result sets -logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG) +#logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG) from elementtree import ElementTree from elementtree.ElementTree import Element, SubElement meta = MetaData() -meta.engine = 'sqlite://' +meta.bind = 'sqlite://' ################################# PART II - Table Metadata ########################################### @@ -94,7 +94,7 @@ mapper(Document, documents, properties={ }) # the _Node objects change the way they load so that a list of _Nodes will organize -# themselves hierarchically using the HierarchicalLoader. this depends on the ordering of +# themselves hierarchically using the ElementTreeMarshal. this depends on the ordering of # nodes being hierarchical as well; relation() always applies at least ROWID/primary key # ordering to rows which will suffice. 
mapper(_Node, elements, properties={ @@ -137,12 +137,12 @@ class ElementTreeMarshal(object): def __set__(self, document, element): def traverse(node): n = _Node() - n.tag = node.tag - n.text = node.text - n.tail = node.tail + n.tag = unicode(node.tag) + n.text = unicode(node.text) + n.tail = unicode(node.tail) document._nodes.append(n) n.children = [traverse(n2) for n2 in node] - n.attributes = [_Attribute(k, v) for k, v in node.attrib.iteritems()] + n.attributes = [_Attribute(unicode(k), unicode(v)) for k, v in node.attrib.iteritems()] return n traverse(element.getroot()) @@ -184,11 +184,10 @@ print document # manually search for a document which contains "/somefile/header/field1:hi" print "\nManual search for /somefile/header/field1=='hi':", line -n1 = elements.alias('n1') -n2 = elements.alias('n2') -n3 = elements.alias('n3') -j = documents.join(n1).join(n2, n1.c.element_id==n2.c.parent_id).join(n3, n2.c.element_id==n3.c.parent_id) -d = session.query(Document).select_from(j).filter(n1.c.tag=='somefile').filter(n2.c.tag=='header').filter(and_(n3.c.tag=='field1', n3.c.text=='hi')).one() +d = session.query(Document).join('_nodes', aliased=True).filter(and_(_Node.parent_id==None, _Node.tag==u'somefile')).\ + join('children', aliased=True, from_joinpoint=True).filter(_Node.tag==u'header').\ + join('children', aliased=True, from_joinpoint=True).filter(and_(_Node.tag==u'field1', _Node.text==u'hi')).\ + one() print d # generalize the above approach into an extremely impoverished xpath function: @@ -196,28 +195,28 @@ def find_document(path, compareto): j = documents prev_elements = None query = session.query(Document) + first = True for i, match in enumerate(re.finditer(r'/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?', path)): (token, attrname, attrvalue) = match.group(1, 2, 3) - a = elements.alias("n%d" % i) - query = query.filter(a.c.tag==token) + if first: + query = query.join('_nodes', aliased=True).filter(_Node.parent_id==None) + first = False + else: + query = query.join('children', aliased=True, from_joinpoint=True) + query = query.filter(_Node.tag==token) if attrname: - attr_alias = attributes.alias('a%d' % i) + query = query.join('attributes', aliased=True, from_joinpoint=True) if attrvalue: - query = query.filter(and_(a.c.element_id==attr_alias.c.element_id, attr_alias.c.name==attrname, attr_alias.c.value==attrvalue)) + query = query.filter(and_(_Attribute.name==attrname, _Attribute.value==attrvalue)) else: - query = query.filter(and_(a.c.element_id==attr_alias.c.element_id, attr_alias.c.name==attrname)) - if prev_elements is not None: - j = j.join(a, prev_elements.c.element_id==a.c.parent_id) - else: - j = j.join(a) - prev_elements = a - return query.options(lazyload('_nodes')).select_from(j).filter(prev_elements.c.text==compareto).all() + query = query.filter(_Attribute.name==attrname) + return query.options(lazyload('_nodes')).filter(_Node.text==compareto).all() for path, compareto in ( - ('/somefile/header/field1', 'hi'), - ('/somefile/field1', 'hi'), - ('/somefile/header/field2', 'there'), - ('/somefile/header/field2[@attr=foo]', 'there') + (u'/somefile/header/field1', u'hi'), + (u'/somefile/field1', u'hi'), + (u'/somefile/header/field2', u'there'), + (u'/somefile/header/field2[@attr=foo]', u'there') ): print "\nDocuments containing '%s=%s':" % (path, compareto), line print [d.filename for d in find_document(path, compareto)] diff --git a/examples/elementtree/pickle.py b/examples/elementtree/pickle.py index 443ca85c3e..e7cd86984e 100644 --- a/examples/elementtree/pickle.py +++ 
b/examples/elementtree/pickle.py @@ -22,8 +22,8 @@ logging.basicConfig() from elementtree import ElementTree -meta = MetaData() -meta.engine = 'sqlite://' +engine = create_engine('sqlite://') +meta = MetaData(engine) # stores a top level record of an XML document. # the "element" column will store the ElementTree document as a BLOB. diff --git a/examples/graphs/graph1.py b/examples/graphs/graph1.py index c2eec44f9f..8188d7c870 100644 --- a/examples/graphs/graph1.py +++ b/examples/graphs/graph1.py @@ -1,6 +1,7 @@ """a directed graph example.""" from sqlalchemy import * +from sqlalchemy.orm import * import logging logging.basicConfig() diff --git a/examples/pickle/custom_pickler.py b/examples/pickle/custom_pickler.py index b45e16e7c6..1c88c88e82 100644 --- a/examples/pickle/custom_pickler.py +++ b/examples/pickle/custom_pickler.py @@ -1,6 +1,7 @@ """illustrates one way to use a custom pickler that is session-aware.""" from sqlalchemy import * +from sqlalchemy.orm import * from sqlalchemy.orm.session import object_session from cStringIO import StringIO from pickle import Pickler, Unpickler @@ -10,15 +11,15 @@ meta = MetaData('sqlite://') meta.bind.echo = True class MyExt(MapperExtension): - def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew): + def populate_instance(self, mapper, selectcontext, row, instance, **flags): MyPickler.sessions.current = selectcontext.session - return EXT_PASS + return EXT_CONTINUE def before_insert(self, mapper, connection, instance): MyPickler.sessions.current = object_session(instance) - return EXT_PASS + return EXT_CONTINUE def before_update(self, mapper, connection, instance): MyPickler.sessions.current = object_session(instance) - return EXT_PASS + return EXT_CONTINUE class MyPickler(object): sessions = threading.local() diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py index e95b978aec..0a9e992a32 100644 --- a/examples/sharding/attribute_shard.py +++ b/examples/sharding/attribute_shard.py @@ -21,8 +21,9 @@ To set up a sharding system, you need: from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.shard import ShardedSession -from sqlalchemy.sql import ColumnOperators -import datetime, operator +from sqlalchemy.sql import operators +from sqlalchemy import sql +import datetime # step 2. databases echo = True @@ -34,13 +35,15 @@ db4 = create_engine('sqlite:///shard4.db', echo=echo) # step 3. create session function. this binds the shard ids # to databases within a ShardedSession and returns it. -def create_session(): - s = ShardedSession(shard_chooser, id_chooser, query_chooser) - s.bind_shard('north_america', db1) - s.bind_shard('asia', db2) - s.bind_shard('europe', db3) - s.bind_shard('south_america', db4) - return s +create_session = sessionmaker(class_=ShardedSession) + +create_session.configure(shards={ + 'north_america':db1, + 'asia':db2, + 'europe':db3, + 'south_america':db4 +}) + # step 4. table setup. meta = MetaData() @@ -105,7 +108,7 @@ shard_lookup = { # note that we need to define conditions for # the WeatherLocation class, as well as our secondary Report class which will # point back to its WeatherLocation via its 'location' attribute. -def shard_chooser(mapper, instance): +def shard_chooser(mapper, instance, clause=None): if isinstance(instance, WeatherLocation): return shard_lookup[instance.continent] else: @@ -116,7 +119,7 @@ def shard_chooser(mapper, instance): # pk so we just return all shard ids. 
often, youd want to do some # kind of round-robin strategy here so that requests are evenly # distributed among DBs -def id_chooser(ident): +def id_chooser(query, ident): return ['north_america', 'asia', 'europe', 'south_america'] # query_chooser. this also returns a list of shard ids, which can @@ -131,9 +134,9 @@ def query_chooser(query): class FindContinent(sql.ClauseVisitor): def visit_binary(self, binary): if binary.left is weather_locations.c.continent: - if binary.operator == operator.eq: + if binary.operator == operators.eq: ids.append(shard_lookup[binary.right.value]) - elif binary.operator == ColumnOperators.in_op: + elif binary.operator == operators.in_op: for bind in binary.right.clauses: ids.append(shard_lookup[bind.value]) @@ -143,6 +146,9 @@ def query_chooser(query): else: return ids +# further configure create_session to use these functions +create_session.configure(shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) + # step 6. mapped classes. class WeatherLocation(object): def __init__(self, continent, city): diff --git a/examples/vertical/dictlike-polymorphic.py b/examples/vertical/dictlike-polymorphic.py new file mode 100644 index 0000000000..4065337c2e --- /dev/null +++ b/examples/vertical/dictlike-polymorphic.py @@ -0,0 +1,265 @@ +"""Mapping a polymorphic-valued vertical table as a dictionary. + +This example illustrates accessing and modifying a "vertical" (or +"properties", or pivoted) table via a dict-like interface. The 'dictlike.py' +example explains the basics of vertical tables and the general approach. This +example adds a twist- the vertical table holds several "value" columns, one +for each type of data that can be stored. For example:: + + Table('properties', metadata + Column('owner_id', Integer, ForeignKey('owner.id'), + primary_key=True), + Column('key', UnicodeText), + Column('type', Unicode(16)), + Column('int_value', Integer), + Column('char_value', UnicodeText), + Column('bool_value', Boolean), + Column('decimal_value', Numeric(10,2))) + +For any given properties row, the value of the 'type' column will point to the +'_value' column active for that row. + +This example approach uses exactly the same dict mapping approach as the +'dictlike' example. It only differs in the mapping for vertical rows. Here, +we'll use a Python @property to build a smart '.value' attribute that wraps up +reading and writing those various '_value' columns and keeps the '.type' up to +date. + +Note: Something much like 'comparable_property' is slated for inclusion in a + future version of SQLAlchemy. +""" + +from sqlalchemy.orm.interfaces import PropComparator, MapperProperty +from sqlalchemy.orm import session as sessionlib, comparable_property + +# Using the VerticalPropertyDictMixin from the base example +from dictlike import VerticalPropertyDictMixin + +class PolymorphicVerticalProperty(object): + """A key/value pair with polymorphic value storage. + + Supplies a smart 'value' attribute that provides convenient read/write + access to the row's current value without the caller needing to worry + about the 'type' attribute or multiple columns. + + The 'value' attribute can also be used for basic comparisons in queries, + allowing the row's logical value to be compared without foreknowledge of + which column it might be in. This is not going to be a very efficient + operation on the database side, but it is possible. 
If you're mapping to + an existing database and you have some rows with a value of str('1') and + others of int(1), then this could be useful. + + Subclasses must provide a 'type_map' class attribute with the following + form:: + + type_map = { + : ('type column value', 'column name'), + # ... + } + + For example,:: + + type_map = { + int: ('integer', 'integer_value'), + str: ('varchar', 'varchar_value'), + } + + Would indicate that a Python int value should be stored in the + 'integer_value' column and the .type set to 'integer'. Conversely, if the + value of '.type' is 'integer, then the 'integer_value' column is consulted + for the current value. + """ + + type_map = { + type(None): (None, None), + } + + class Comparator(PropComparator): + """A comparator for .value, builds a polymorphic comparison via CASE. + + Optional. If desired, install it as a comparator in the mapping:: + + mapper(..., properties={ + 'value': comparable_property(PolymorphicVerticalProperty.Comparator, + PolymorphicVerticalProperty.value) + }) + """ + + def _case(self): + cls = self.prop.parent.class_ + whens = [(text("'%s'" % p[0]), getattr(cls, p[1])) + for p in cls.type_map.values() + if p[1] is not None] + return case(whens, cls.type, null()) + def __eq__(self, other): + return cast(self._case(), String) == cast(other, String) + def __ne__(self, other): + return cast(self._case(), String) != cast(other, String) + + def __init__(self, key, value=None): + self.key = key + self.value = value + + def _get_value(self): + for discriminator, field in self.type_map.values(): + if self.type == discriminator: + return getattr(self, field) + return None + + def _set_value(self, value): + py_type = type(value) + if py_type not in self.type_map: + raise TypeError(py_type) + + for field_type in self.type_map: + discriminator, field = self.type_map[field_type] + field_value = None + if py_type == field_type: + self.type = discriminator + field_value = value + if field is not None: + setattr(self, field, field_value) + + def _del_value(self): + self._set_value(None) + + value = property(_get_value, _set_value, _del_value, doc= + """The logical value of this property.""") + + def __repr__(self): + return '<%s %r=%r>' % (self.__class__.__name__, self.key, self.value) + + +if __name__ == '__main__': + from sqlalchemy import * + from sqlalchemy.orm import mapper, relation, create_session + from sqlalchemy.orm.collections import attribute_mapped_collection + + metadata = MetaData() + + animals = Table('animal', metadata, + Column('id', Integer, primary_key=True), + Column('name', Unicode(100))) + + chars = Table('facts', metadata, + Column('animal_id', Integer, ForeignKey('animal.id'), + primary_key=True), + Column('key', Unicode(64), primary_key=True), + Column('type', Unicode(16), default=None), + Column('int_value', Integer, default=None), + Column('char_value', UnicodeText, default=None), + Column('boolean_value', Boolean, default=None)) + + class AnimalFact(PolymorphicVerticalProperty): + type_map = { + int: (u'integer', 'int_value'), + unicode: (u'char', 'char_value'), + bool: (u'boolean', 'boolean_value'), + type(None): (None, None), + } + + class Animal(VerticalPropertyDictMixin): + """An animal. 
+ + Animal facts are available via the 'facts' property or by using + dict-like accessors on an Animal instance:: + + cat['color'] = 'calico' + # or, equivalently: + cat.facts['color'] = AnimalFact('color', 'calico') + """ + + _property_type = AnimalFact + _property_mapping = 'facts' + + def __init__(self, name): + self.name = name + + def __repr__(self): + return '<%s %r>' % (self.__class__.__name__, self.name) + + + mapper(Animal, animals, properties={ + 'facts': relation( + AnimalFact, backref='animal', + collection_class=attribute_mapped_collection('key')), + }) + + mapper(AnimalFact, chars, properties={ + 'value': comparable_property(AnimalFact.Comparator, AnimalFact.value) + }) + + metadata.bind = 'sqlite:///' + metadata.create_all() + session = create_session() + + stoat = Animal(u'stoat') + stoat[u'color'] = u'red' + stoat[u'cuteness'] = 7 + stoat[u'weasel-like'] = True + + session.save(stoat) + session.flush() + session.clear() + + critter = session.query(Animal).filter(Animal.name == u'stoat').one() + print critter[u'color'] + print critter[u'cuteness'] + + print "changing cuteness value and type:" + critter[u'cuteness'] = u'very cute' + + metadata.bind.echo = True + session.flush() + metadata.bind.echo = False + + marten = Animal(u'marten') + marten[u'cuteness'] = 5 + marten[u'weasel-like'] = True + marten[u'poisonous'] = False + session.save(marten) + + shrew = Animal(u'shrew') + shrew[u'cuteness'] = 5 + shrew[u'weasel-like'] = False + shrew[u'poisonous'] = True + + session.save(shrew) + session.flush() + + q = (session.query(Animal). + filter(Animal.facts.any( + and_(AnimalFact.key == u'weasel-like', + AnimalFact.value == True)))) + print 'weasel-like animals', q.all() + + # Save some typing by wrapping that up in a function: + with_characteristic = lambda key, value: and_(AnimalFact.key == key, + AnimalFact.value == value) + + q = (session.query(Animal). + filter(Animal.facts.any( + with_characteristic(u'weasel-like', True)))) + print 'weasel-like animals again', q.all() + + q = (session.query(Animal). + filter(Animal.facts.any(with_characteristic(u'poisonous', False)))) + print 'animals with poisonous=False', q.all() + + q = (session.query(Animal). + filter(or_(Animal.facts.any( + with_characteristic(u'poisonous', False)), + not_(Animal.facts.any(AnimalFact.key == u'poisonous'))))) + print 'non-poisonous animals', q.all() + + q = (session.query(Animal). + filter(Animal.facts.any(AnimalFact.value == 5))) + print 'any animal with a .value of 5', q.all() + + # Facts can be queried as well. + q = (session.query(AnimalFact). + filter(with_characteristic(u'cuteness', u'very cute'))) + print q.all() + + + metadata.drop_all() diff --git a/examples/vertical/dictlike.py b/examples/vertical/dictlike.py new file mode 100644 index 0000000000..5f478d7d05 --- /dev/null +++ b/examples/vertical/dictlike.py @@ -0,0 +1,247 @@ +"""Mapping a vertical table as a dictionary. + +This example illustrates accessing and modifying a "vertical" (or +"properties", or pivoted) table via a dict-like interface. These are tables +that store free-form object properties as rows instead of columns. 
For +example, instead of:: + + # A regular ("horizontal") table has columns for 'species' and 'size' + Table('animal', metadata, + Column('id', Integer, primary_key=True), + Column('species', Unicode), + Column('size', Unicode)) + +A vertical table models this as two tables: one table for the base or parent +entity, and another related table holding key/value pairs:: + + Table('animal', metadata, + Column('id', Integer, primary_key=True)) + + # The properties table will have one row for a 'species' value, and + # another row for the 'size' value. + Table('properties', metadata + Column('animal_id', Integer, ForeignKey('animal.id'), + primary_key=True), + Column('key', UnicodeText), + Column('value', UnicodeText)) + +Because the key/value pairs in a vertical scheme are not fixed in advance, +accessing them like a Python dict can be very convenient. The example below +can be used with many common vertical schemas as-is or with minor adaptations. +""" + +class VerticalProperty(object): + """A key/value pair. + + This class models rows in the vertical table. + """ + + def __init__(self, key, value): + self.key = key + self.value = value + + def __repr__(self): + return '<%s %r=%r>' % (self.__class__.__name__, self.key, self.value) + + +class VerticalPropertyDictMixin(object): + """Adds obj[key] access to a mapped class. + + This is a mixin class. It can be inherited from directly, or included + with multiple inheritence. + + Classes using this mixin must define two class properties:: + + _property_type: + The mapped type of the vertical key/value pair instances. Will be + invoked with two positional arugments: key, value + + _property_mapping: + A string, the name of the Python attribute holding a dict-based + relation of _property_type instances. + + Using the VerticalProperty class above as an example,:: + + class MyObj(VerticalPropertyDictMixin): + _property_type = VerticalProperty + _property_mapping = 'props' + + mapper(MyObj, sometable, properties={ + 'props': relation(VerticalProperty, + collection_class=attribute_mapped_collection('key'))}) + + Dict-like access to MyObj is proxied through to the 'props' relation:: + + myobj['key'] = 'value' + # ...is shorthand for: + myobj.props['key'] = VerticalProperty('key', 'value') + + myobj['key'] = 'updated value'] + # ...is shorthand for: + myobj.props['key'].value = 'updated value' + + print myobj['key'] + # ...is shorthand for: + print myobj.props['key'].value + + """ + + _property_type = VerticalProperty + _property_mapping = None + + __map = property(lambda self: getattr(self, self._property_mapping)) + + def __getitem__(self, key): + return self.__map[key].value + + def __setitem__(self, key, value): + property = self.__map.get(key, None) + if property is None: + self.__map[key] = self._property_type(key, value) + else: + property.value = value + + def __delitem__(self, key): + del self.__map[key] + + def __contains__(self, key): + return key in self.__map + + # Implement other dict methods to taste. 
Here are some examples: + def keys(self): + return self.__map.keys() + + def values(self): + return [prop.value for prop in self.__map.values()] + + def items(self): + return [(key, prop.value) for key, prop in self.__map.items()] + + def __iter__(self): + return iter(self.keys()) + + +if __name__ == '__main__': + from sqlalchemy import * + from sqlalchemy.orm import mapper, relation, create_session + from sqlalchemy.orm.collections import attribute_mapped_collection + + metadata = MetaData() + + # Here we have named animals, and a collection of facts about them. + animals = Table('animal', metadata, + Column('id', Integer, primary_key=True), + Column('name', Unicode(100))) + + facts = Table('facts', metadata, + Column('animal_id', Integer, ForeignKey('animal.id'), + primary_key=True), + Column('key', Unicode(64), primary_key=True), + Column('value', UnicodeText, default=None),) + + class AnimalFact(VerticalProperty): + """A fact about an animal.""" + + class Animal(VerticalPropertyDictMixin): + """An animal. + + Animal facts are available via the 'facts' property or by using + dict-like accessors on an Animal instance:: + + cat['color'] = 'calico' + # or, equivalently: + cat.facts['color'] = AnimalFact('color', 'calico') + """ + + _property_type = AnimalFact + _property_mapping = 'facts' + + def __init__(self, name): + self.name = name + + def __repr__(self): + return '<%s %r>' % (self.__class__.__name__, self.name) + + + mapper(Animal, animals, properties={ + 'facts': relation( + AnimalFact, backref='animal', + collection_class=attribute_mapped_collection('key')), + }) + mapper(AnimalFact, facts) + + + metadata.bind = 'sqlite:///' + metadata.create_all() + session = create_session() + + stoat = Animal(u'stoat') + stoat[u'color'] = u'reddish' + stoat[u'cuteness'] = u'somewhat' + + # dict-like assignment transparently creates entries in the + # stoat.facts collection: + print stoat.facts[u'color'] + + session.save(stoat) + session.flush() + session.clear() + + critter = session.query(Animal).filter(Animal.name == u'stoat').one() + print critter[u'color'] + print critter[u'cuteness'] + + critter[u'cuteness'] = u'very' + + print 'changing cuteness:' + metadata.bind.echo = True + session.flush() + metadata.bind.echo = False + + marten = Animal(u'marten') + marten[u'color'] = u'brown' + marten[u'cuteness'] = u'somewhat' + session.save(marten) + + shrew = Animal(u'shrew') + shrew[u'cuteness'] = u'somewhat' + shrew[u'poisonous-part'] = u'saliva' + session.save(shrew) + + loris = Animal(u'slow loris') + loris[u'cuteness'] = u'fairly' + loris[u'poisonous-part'] = u'elbows' + session.save(loris) + session.flush() + + q = (session.query(Animal). + filter(Animal.facts.any( + and_(AnimalFact.key == u'color', + AnimalFact.value == u'reddish')))) + print 'reddish animals', q.all() + + # Save some typing by wrapping that up in a function: + with_characteristic = lambda key, value: and_(AnimalFact.key == key, + AnimalFact.value == value) + + q = (session.query(Animal). + filter(Animal.facts.any( + with_characteristic(u'color', u'brown')))) + print 'brown animals', q.all() + + q = (session.query(Animal). + filter(not_(Animal.facts.any( + with_characteristic(u'poisonous-part', u'elbows'))))) + print 'animals without poisonous-part == elbows', q.all() + + q = (session.query(Animal). + filter(Animal.facts.any(AnimalFact.value == u'somewhat'))) + print 'any animal with any .value of "somewhat"', q.all() + + # Facts can be queried as well. + q = (session.query(AnimalFact). 
+ filter(with_characteristic(u'cuteness', u'very'))) + print 'just the facts', q.all() + + + metadata.drop_all() diff --git a/examples/vertical/vertical.py b/examples/vertical/vertical.py index e3b48c3369..225beeffe9 100644 --- a/examples/vertical/vertical.py +++ b/examples/vertical/vertical.py @@ -10,6 +10,8 @@ import datetime e = MetaData('sqlite://') e.bind.echo = True +Session = scoped_session(sessionmaker(transactional=True)) + # this table represents Entity objects. each Entity gets a row in this table, # with a primary key and a title. entities = Table('entities', e, @@ -84,10 +86,7 @@ class EntityValue(object): the value to the underlying datatype of its EntityField.""" def __init__(self, key=None, value=None): if key is not None: - sess = create_session() - self.field = sess.query(EntityField).get_by(name=key) or EntityField(key) - # close the session, which will make a loaded EntityField a detached instance - sess.close() + self.field = Session.query(EntityField).filter(EntityField.name==key).first() or EntityField(key) if self.field.datatype is None: if isinstance(value, int): self.field.datatype = 'int' @@ -123,7 +122,7 @@ mapper(Entity, entities, properties = { # create two entities. the objects can be used about as regularly as # any object can. -session = create_session() +session = Session() entity = Entity() entity.title = 'this is the first entity' entity.name = 'this is the name' diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 6e95fd7e1d..343a0cac8c 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -1,28 +1,34 @@ # __init__.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy.types import * -from sqlalchemy.sql import * -from sqlalchemy.schema import * +import inspect +from sqlalchemy.types import \ + BLOB, BOOLEAN, CHAR, CLOB, DATE, DATETIME, DECIMAL, FLOAT, INT, \ + NCHAR, NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, \ + Binary, Boolean, Date, DateTime, Float, Integer, Interval, Numeric, \ + PickleType, SmallInteger, String, Text, Time, Unicode, UnicodeText -from sqlalchemy.engine import create_engine +from sqlalchemy.sql import \ + func, modifier, text, literal, literal_column, null, alias, \ + and_, or_, not_, \ + select, subquery, union, union_all, insert, update, delete, \ + join, outerjoin, \ + bindparam, outparam, asc, desc, collate, \ + except_, except_all, exists, intersect, intersect_all, \ + between, case, cast, distinct, extract -def __figure_version(): - try: - from pkg_resources import require - import os - # NOTE: this only works when the package is either installed, - # or has an .egg-info directory present (i.e. 
wont work with raw SVN checkout) - info = require('sqlalchemy')[0] - if os.path.dirname(os.path.dirname(__file__)) == info.location: - return info.version - else: - return '(not installed)' - except: - return '(not installed)' - -__version__ = __figure_version() - +from sqlalchemy.schema import \ + MetaData, ThreadLocalMetaData, Table, Column, ForeignKey, \ + Sequence, Index, ForeignKeyConstraint, PrimaryKeyConstraint, \ + CheckConstraint, UniqueConstraint, Constraint, \ + PassiveDefault, ColumnDefault, DDL + +from sqlalchemy.engine import create_engine, engine_from_config + +__all__ = [ name for name, obj in locals().items() + if not (name.startswith('_') or inspect.ismodule(obj)) ] + +__version__ = 'svn' diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py deleted file mode 100644 index 22227d56a8..0000000000 --- a/lib/sqlalchemy/ansisql.py +++ /dev/null @@ -1,1075 +0,0 @@ -# ansisql.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Defines ANSI SQL operations. - -Contains default implementations for the abstract objects in the sql -module. -""" - -import string, re, sets, operator - -from sqlalchemy import schema, sql, engine, util, exceptions -from sqlalchemy.engine import default - -ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', - 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', - 'SESSION_USER', 'USER']) - - -RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', - 'as', 'asc', 'asymmetric', 'authorization', 'between', - 'binary', 'both', 'case', 'cast', 'check', 'collate', - 'column', 'constraint', 'create', 'cross', 'current_date', - 'current_role', 'current_time', 'current_timestamp', - 'current_user', 'default', 'deferrable', 'desc', - 'distinct', 'do', 'else', 'end', 'except', 'false', - 'for', 'foreign', 'freeze', 'from', 'full', 'grant', - 'group', 'having', 'ilike', 'in', 'initially', 'inner', - 'intersect', 'into', 'is', 'isnull', 'join', 'leading', - 'left', 'like', 'limit', 'localtime', 'localtimestamp', - 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', - 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', - 'placing', 'primary', 'references', 'right', 'select', - 'session_user', 'similar', 'some', 'symmetric', 'table', - 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', - 'using', 'verbose', 'when', 'where']) - -LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$') -ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') - -BIND_PARAMS = re.compile(r'(?', - operator.ge : '>=', - operator.eq : '=', - sql.ColumnOperators.concat_op : '||', - sql.ColumnOperators.like_op : 'LIKE', - sql.ColumnOperators.notlike_op : 'NOT LIKE', - sql.ColumnOperators.ilike_op : 'ILIKE', - sql.ColumnOperators.notilike_op : 'NOT ILIKE', - sql.ColumnOperators.between_op : 'BETWEEN', - sql.ColumnOperators.in_op : 'IN', - sql.ColumnOperators.notin_op : 'NOT IN', - sql.ColumnOperators.comma_op : ', ', - sql.Operators.from_ : 'FROM', - sql.Operators.as_ : 'AS', - sql.Operators.exists : 'EXISTS', - sql.Operators.is_ : 'IS', - sql.Operators.isnot : 'IS NOT' -} - -class ANSIDialect(default.DefaultDialect): - def __init__(self, cache_identifiers=True, **kwargs): - super(ANSIDialect,self).__init__(**kwargs) - self.identifier_preparer = self.preparer() - self.cache_identifiers = cache_identifiers 
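Looping back to the reworked lib/sqlalchemy/__init__.py above: the package now imports its public names explicitly and derives __all__ from whatever remains in the module namespace. A minimal sketch of that filtering idea follows; the public_names helper is invented here purely for illustration.

    import inspect

    def public_names(namespace):
        # keep names that are neither private nor bound to modules,
        # mirroring the __all__ computation in the new __init__.py
        return [name for name, obj in namespace.items()
                if not (name.startswith('_') or inspect.ismodule(obj))]

    import sqlalchemy
    print sorted(public_names(vars(sqlalchemy)))[:5]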
- - def create_connect_args(self): - return ([],{}) - - def schemagenerator(self, *args, **kwargs): - return ANSISchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return ANSISchemaDropper(self, *args, **kwargs) - - def compiler(self, statement, parameters, **kwargs): - return ANSICompiler(self, statement, parameters, **kwargs) - - def preparer(self): - """Return an IdentifierPreparer. - - This object is used to format table and column names including - proper quoting and case conventions. - """ - return ANSIIdentifierPreparer(self) - -class ANSICompiler(engine.Compiled, sql.ClauseVisitor): - """Default implementation of Compiled. - - Compiles ClauseElements into ANSI-compliant SQL strings. - """ - - __traverse_options__ = {'column_collections':False, 'entry':True} - - operators = OPERATORS - - def __init__(self, dialect, statement, parameters=None, **kwargs): - """Construct a new ``ANSICompiler`` object. - - dialect - Dialect to be used - - statement - ClauseElement to be compiled - - parameters - optional dictionary indicating a set of bind parameters - specified with this Compiled object. These parameters are - the *default* key/value pairs when the Compiled is executed, - and also may affect the actual compilation, as in the case - of an INSERT where the actual columns inserted will - correspond to the keys present in the parameters. - """ - - super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs) - - # if we are insert/update. set to true when we visit an INSERT or UPDATE - self.isinsert = self.isupdate = False - - # a dictionary of bind parameter keys to _BindParamClause instances. - self.binds = {} - - # a dictionary of _BindParamClause instances to "compiled" names that are - # actually present in the generated SQL - self.bind_names = {} - - # when the compiler visits a SELECT statement, the clause object is appended - # to this stack. various visit operations will check this stack to determine - # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only - # apply to the topmost-level SELECT statement ?) - self.select_stack = [] - - # a dictionary of result-set column names (strings) to TypeEngine instances, - # which will be passed to a ResultProxy and used for resultset-level value conversion - self.typemap = {} - - # a dictionary of select columns labels mapped to their "generated" label - self.column_labels = {} - - # a dictionary of ClauseElement subclasses to counters, which are used to - # generate truncated identifier names or "anonymous" identifiers such as - # for aliases - self.generated_ids = {} - - # default formatting style for bind parameters - self.bindtemplate = ":%s" - - # paramstyle from the dialect (comes from DBAPI) - self.paramstyle = dialect.paramstyle - - # true if the paramstyle is positional - self.positional = dialect.positional - - # a list of the compiled's bind parameter names, used to help - # formulate a positional argument list - self.positiontup = [] - - # an ANSIIdentifierPreparer that formats the quoting of identifiers - self.preparer = dialect.identifier_preparer - - # a dictionary containing attributes about all select() - # elements located within the clause, regarding which are subqueries, which are - # selected from, and which elements should be correlated to an enclosing select. - # used mostly to determine the list of FROM elements for each select statement, as well - # as some dialect-specific rules regarding subqueries. 
- self.correlate_state = {} - - # for UPDATE and INSERT statements, a set of columns whos values are being set - # from a SQL expression (i.e., not one of the bind parameter values). if present, - # default-value logic in the Dialect knows not to fire off column defaults - # and also knows postfetching will be needed to get the values represented by these - # parameters. - self.inline_params = None - - def after_compile(self): - # this re will search for params like :param - # it has a negative lookbehind for an extra ':' so that it doesnt match - # postgres '::text' tokens - text = self.string - if ':' not in text: - return - - if self.paramstyle=='pyformat': - text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text) - elif self.positional: - params = BIND_PARAMS.finditer(text) - for p in params: - self.positiontup.append(p.group(1)) - if self.paramstyle=='qmark': - text = BIND_PARAMS.sub('?', text) - elif self.paramstyle=='format': - text = BIND_PARAMS.sub('%s', text) - elif self.paramstyle=='numeric': - i = [0] - def getnum(x): - i[0] += 1 - return str(i[0]) - text = BIND_PARAMS.sub(getnum, text) - # un-escape any \:params - text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text) - self.string = text - - def compile(self): - self.string = self.process(self.statement) - self.after_compile() - - def process(self, obj, **kwargs): - return self.traverse_single(obj, **kwargs) - - def is_subquery(self, select): - return self.correlate_state[select].get('is_subquery', False) - - def get_whereclause(self, obj): - """given a FROM clause, return an additional WHERE condition that should be - applied to a SELECT. - - Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN - constructs in non-ansi mode. - """ - - return None - - def construct_params(self, params): - """Return a sql.ClauseParameters object. - - Combines the given bind parameter dictionary (string keys to object values) - with the _BindParamClause objects stored within this Compiled object - to produce a ClauseParameters structure, representing the bind arguments - for a single statement execution, or one element of an executemany execution. - """ - - if self.parameters is not None: - bindparams = self.parameters.copy() - else: - bindparams = {} - bindparams.update(params) - d = sql.ClauseParameters(self.dialect, self.positiontup) - for b in self.binds.values(): - name = self.bind_names[b] - d.set_parameter(b, b.value, name) - - for key, value in bindparams.iteritems(): - try: - b = self.binds[key] - except KeyError: - continue - name = self.bind_names[b] - d.set_parameter(b, value, name) - - return d - - def default_from(self): - """Called when a SELECT statement has no froms, and no FROM clause is to be appended. - - Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. 
- """ - - return "" - - def visit_grouping(self, grouping, **kwargs): - return "(" + self.process(grouping.elem) + ")" - - def visit_label(self, label): - labelname = self._truncated_identifier("colident", label.name) - - if len(self.select_stack): - self.typemap.setdefault(labelname.lower(), label.obj.type) - if isinstance(label.obj, sql._ColumnClause): - self.column_labels[label.obj._label] = labelname - self.column_labels[label.name] = labelname - return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)]) - - def visit_column(self, column, **kwargs): - # there is actually somewhat of a ruleset when you would *not* necessarily - # want to truncate a column identifier, if its mapped to the name of a - # physical column. but thats very hard to identify at this point, and - # the identifier length should be greater than the id lengths of any physical - # columns so should not matter. - if not column.is_literal: - name = self._truncated_identifier("colident", column.name) - else: - name = column.name - - if len(self.select_stack): - # if we are within a visit to a Select, set up the "typemap" - # for this column which is used to translate result set values - self.typemap.setdefault(name.lower(), column.type) - self.column_labels.setdefault(column._label, name.lower()) - - if column.table is None or not column.table.named_with_column(): - return self.preparer.format_column(column, name=name) - else: - if column.table.oid_column is column: - n = self.dialect.oid_column_name(column) - if n is not None: - return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n) - elif len(column.table.primary_key) != 0: - pk = list(column.table.primary_key)[0] - pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name)) - return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name)) - else: - return None - else: - return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name)) - - - def visit_fromclause(self, fromclause, **kwargs): - return fromclause.name - - def visit_index(self, index, **kwargs): - return index.name - - def visit_typeclause(self, typeclause, **kwargs): - return typeclause.type.dialect_impl(self.dialect).get_col_spec() - - def visit_textclause(self, textclause, **kwargs): - for bind in textclause.bindparams.values(): - self.process(bind) - if textclause.typemap is not None: - self.typemap.update(textclause.typemap) - return textclause.text - - def visit_null(self, null, **kwargs): - return 'NULL' - - def visit_clauselist(self, clauselist, **kwargs): - sep = clauselist.operator - if sep is None: - sep = " " - elif sep == sql.ColumnOperators.comma_op: - sep = ', ' - else: - sep = " " + self.operator_string(clauselist.operator) + " " - return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep) - - def apply_function_parens(self, func): - return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - - def visit_calculatedclause(self, clause, **kwargs): - return self.process(clause.clause_expr) - - def visit_cast(self, cast, **kwargs): - if len(self.select_stack): - # not sure if we want to set the typemap here... 
- self.typemap.setdefault("CAST", cast.type) - return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) - - def visit_function(self, func, **kwargs): - if len(self.select_stack): - self.typemap.setdefault(func.name, func.type) - if not self.apply_function_parens(func): - return ".".join(func.packagenames + [func.name]) - else: - return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) - - def visit_compound_select(self, cs, asfrom=False, **kwargs): - text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ") - group_by = self.process(cs._group_by_clause) - if group_by: - text += " GROUP BY " + group_by - text += self.order_by_clause(cs) - text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" - - if asfrom: - return "(" + text + ")" - else: - return text - - def visit_unary(self, unary, **kwargs): - s = self.process(unary.element) - if unary.operator: - s = self.operator_string(unary.operator) + " " + s - if unary.modifier: - s = s + " " + unary.modifier - return s - - def visit_binary(self, binary, **kwargs): - op = self.operator_string(binary.operator) - if callable(op): - return op(self.process(binary.left), self.process(binary.right)) - else: - return self.process(binary.left) + " " + op + " " + self.process(binary.right) - - def operator_string(self, operator): - return self.operators.get(operator, str(operator)) - - def visit_bindparam(self, bindparam, **kwargs): - # apply truncation to the ultimate generated name - - if bindparam.shortname != bindparam.key: - self.binds.setdefault(bindparam.shortname, bindparam) - - if bindparam.unique: - count = 1 - key = bindparam.key - # redefine the generated name of the bind param in the case - # that we have multiple conflicting bind parameters. 
- while self.binds.setdefault(key, bindparam) is not bindparam: - tag = "_%d" % count - key = bindparam.key + tag - count += 1 - bindparam.key = key - return self.bindparam_string(self._truncate_bindparam(bindparam)) - else: - existing = self.binds.get(bindparam.key) - if existing is not None and existing.unique: - raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) - self.binds[bindparam.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) - - def _truncate_bindparam(self, bindparam): - if bindparam in self.bind_names: - return self.bind_names[bindparam] - - bind_name = bindparam.key - bind_name = self._truncated_identifier("bindparam", bind_name) - # add to bind_names for translation - self.bind_names[bindparam] = bind_name - - return bind_name - - def _truncated_identifier(self, ident_class, name): - if (ident_class, name) in self.generated_ids: - return self.generated_ids[(ident_class, name)] - - anonname = self._anonymize(name) - if len(anonname) > self.dialect.max_identifier_length(): - counter = self.generated_ids.get(ident_class, 1) - truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:] - self.generated_ids[ident_class] = counter + 1 - else: - truncname = anonname - self.generated_ids[(ident_class, name)] = truncname - return truncname - - def _anonymize(self, name): - def anon(match): - (ident, derived) = match.group(1,2) - if ('anonymous', ident) in self.generated_ids: - return self.generated_ids[('anonymous', ident)] - else: - anonymous_counter = self.generated_ids.get('anonymous', 1) - newname = derived + "_" + str(anonymous_counter) - self.generated_ids['anonymous'] = anonymous_counter + 1 - self.generated_ids[('anonymous', ident)] = newname - return newname - return re.sub(r'{ANON (-?\d+) (.*)}', anon, name) - - def bindparam_string(self, name): - return self.bindtemplate % name - - def visit_alias(self, alias, asfrom=False, **kwargs): - if asfrom: - return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name)) - else: - return self.process(alias.original, **kwargs) - - def label_select_column(self, select, column): - """convert a column from a select's "columns" clause. - - given a select() and a column element from its inner_columns collection, return a - Label object if this column should be labeled in the columns clause. Otherwise, - return None and the column will be used as-is. - - The calling method will traverse the returned label to acquire its string - representation. - """ - - # SQLite doesnt like selecting from a subquery where the column - # names look like table.colname. so if column is in a "selected from" - # subquery, label it synoymously with its column name - if \ - self.correlate_state[select].get('is_selected_from', False) and \ - isinstance(column, sql._ColumnClause) and \ - not column.is_literal and \ - column.table is not None and \ - not isinstance(column.table, sql.Select): - return column.label(column.name) - else: - return None - - def visit_select(self, select, asfrom=False, **kwargs): - - select._calculate_correlations(self.correlate_state) - self.select_stack.append(select) - - # the actual list of columns to print in the SELECT column list. 
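As a standalone illustration of the anonymous-identifier handling in the (now removed) compiler above: placeholder names of the form "{ANON <id> <derived>}" are rewritten to "<derived>_<counter>", and the result is cached so the same placeholder always maps to the same generated name. The sketch below simplifies the compiler's generated_ids bookkeeping into a plain dict and list.

    import re

    generated = {}
    counter = [0]

    def anonymize(name):
        # "{ANON 42 users}" -> "users_1", stable across repeated calls
        def anon(match):
            ident, derived = match.group(1, 2)
            if ident not in generated:
                counter[0] += 1
                generated[ident] = "%s_%d" % (derived, counter[0])
            return generated[ident]
        return re.sub(r'{ANON (-?\d+) (.*)}', anon, name)

    print anonymize('{ANON 42 users}')        # users_1
    print anonymize('{ANON 42 users}')        # users_1 (cached)
    print anonymize('{ANON 43 addresses}')    # addresses_2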
- inner_columns = util.OrderedSet() - - froms = select._get_display_froms(self.correlate_state) - - for co in select.inner_columns: - if select.use_labels: - labelname = co._label - if labelname is not None: - l = co.label(labelname) - inner_columns.add(self.process(l)) - else: - self.traverse(co) - inner_columns.add(self.process(co)) - else: - l = self.label_select_column(select, co) - if l is not None: - inner_columns.add(self.process(l)) - else: - inner_columns.add(self.process(co)) - - self.select_stack.pop(-1) - - collist = string.join(inner_columns.difference(util.Set([None])), ', ') - - text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " - text += self.get_select_precolumns(select) - text += collist - - whereclause = select._whereclause - - from_strings = [] - for f in froms: - from_strings.append(self.process(f, asfrom=True)) - - w = self.get_whereclause(f) - if w is not None: - if whereclause is not None: - whereclause = sql.and_(w, whereclause) - else: - whereclause = w - - if len(froms): - text += " \nFROM " - text += string.join(from_strings, ', ') - else: - text += self.default_from() - - if whereclause is not None: - t = self.process(whereclause) - if t: - text += " \nWHERE " + t - - group_by = self.process(select._group_by_clause) - if group_by: - text += " GROUP BY " + group_by - - if select._having is not None: - t = self.process(select._having) - if t: - text += " \nHAVING " + t - - text += self.order_by_clause(select) - text += (select._limit or select._offset) and self.limit_clause(select) or "" - text += self.for_update_clause(select) - - if asfrom: - return "(" + text + ")" - else: - return text - - def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just before column list.""" - return select._distinct and "DISTINCT " or "" - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - if order_by: - return " ORDER BY " + order_by - else: - return "" - - def for_update_clause(self, select): - if select.for_update: - return " FOR UPDATE" - else: - return "" - - def limit_clause(self, select): - text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - text += " \n LIMIT -1" - text += " OFFSET " + str(select._offset) - return text - - def visit_table(self, table, asfrom=False, **kwargs): - if asfrom: - return self.preparer.format_table(table) - else: - return "" - - def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ - self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) - - def uses_sequences_for_inserts(self): - return False - - def visit_insert(self, insert_stmt): - - # search for columns who will be required to have an explicit bound value. 
- # for inserts, this includes Python-side defaults, columns with sequences for dialects - # that support sequences, and primary key columns for dialects that explicitly insert - # pre-generated primary key values - required_cols = util.Set() - class DefaultVisitor(schema.SchemaVisitor): - def visit_column(s, cd): - if c.primary_key and self.uses_sequences_for_inserts(): - required_cols.add(c) - def visit_column_default(s, cd): - required_cols.add(c) - def visit_sequence(s, seq): - if self.uses_sequences_for_inserts(): - required_cols.add(c) - vis = DefaultVisitor() - for c in insert_stmt.table.c: - if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): - vis.traverse(c) - - self.isinsert = True - colparams = self._get_colparams(insert_stmt, required_cols) - - return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + - " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") - - def visit_update(self, update_stmt): - update_stmt._calculate_correlations(self.correlate_state) - - # search for columns who will be required to have an explicit bound value. - # for updates, this includes Python-side "onupdate" defaults. - required_cols = util.Set() - class OnUpdateVisitor(schema.SchemaVisitor): - def visit_column_onupdate(s, cd): - required_cols.add(c) - vis = OnUpdateVisitor() - for c in update_stmt.table.c: - if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): - vis.traverse(c) - - self.isupdate = True - colparams = self._get_colparams(update_stmt, required_cols) - - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ') - - if update_stmt._whereclause: - text += " WHERE " + self.process(update_stmt._whereclause) - - return text - - def _get_colparams(self, stmt, required_cols): - """create a set of tuples representing column/string pairs for use - in an INSERT or UPDATE statement. - - This method may generate new bind params within this compiled - based on the given set of "required columns", which are required - to have a value set in the statement. 
- """ - - def create_bind_param(col, value): - bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True) - self.binds[col.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) - - # no parameters in the statement, no parameters in the - # compiled params - return binds for all columns - if self.parameters is None and stmt.parameters is None: - return [(c, create_bind_param(c, None)) for c in stmt.table.columns] - - def create_clause_param(col, value): - self.traverse(value) - self.inline_params.add(col) - return self.process(value) - - self.inline_params = util.Set() - - def to_col(key): - if not isinstance(key, sql._ColumnClause): - return stmt.table.columns.get(unicode(key), key) - else: - return key - - # if we have statement parameters - set defaults in the - # compiled params - if self.parameters is None: - parameters = {} - else: - parameters = dict([(to_col(k), v) for k, v in self.parameters.iteritems()]) - - if stmt.parameters is not None: - for k, v in stmt.parameters.iteritems(): - parameters.setdefault(to_col(k), v) - - for col in required_cols: - parameters.setdefault(col, None) - - # create a list of column assignment clauses as tuples - values = [] - for c in stmt.table.columns: - if c in parameters: - value = parameters[c] - if sql._is_literal(value): - value = create_bind_param(c, value) - else: - value = create_clause_param(c, value) - values.append((c, value)) - - return values - - def visit_delete(self, delete_stmt): - delete_stmt._calculate_correlations(self.correlate_state) - - text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) - - if delete_stmt._whereclause: - text += " WHERE " + self.process(delete_stmt._whereclause) - - return text - - def visit_savepoint(self, savepoint_stmt): - return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) - - def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) - - def visit_release_savepoint(self, savepoint_stmt): - return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) - - def __str__(self): - return self.string - -class ANSISchemaBase(engine.SchemaIterator): - def find_alterables(self, tables): - alterables = [] - class FindAlterables(schema.SchemaVisitor): - def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and constraint.table in tables: - alterables.append(constraint) - findalterables = FindAlterables() - for table in tables: - for c in table.constraints: - findalterables.traverse(c) - return alterables - -class ANSISchemaGenerator(ANSISchemaBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables and util.Set(tables) or None - self.preparer = dialect.preparer() - self.dialect = dialect - - def get_column_specification(self, column, first_pk=False): - raise NotImplementedError() - - def visit_metadata(self, metadata): - collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] - for table in collection: - self.traverse_single(table) - if self.dialect.supports_alter(): - for alterable in self.find_alterables(collection): - self.add_foreignkey(alterable) - - def visit_table(self, table): - for column in table.columns: - if 
column.default is not None: - self.traverse_single(column.default) - - self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") - - separator = "\n" - - # if only one primary key, specify it along with the column - first_pk = False - for column in table.columns: - self.append(separator) - separator = ", \n" - self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) - if column.primary_key: - first_pk = True - for constraint in column.constraints: - self.traverse_single(constraint) - - # On some DB order is significant: visit PK first, then the - # other constraints (engine.ReflectionTest.testbasic failed on FB2) - if len(table.primary_key): - self.traverse_single(table.primary_key) - for constraint in [c for c in table.constraints if c is not table.primary_key]: - self.traverse_single(constraint) - - self.append("\n)%s\n\n" % self.post_create_table(table)) - self.execute() - if hasattr(table, 'indexes'): - for index in table.indexes: - self.traverse_single(index) - - def post_create_table(self, table): - return '' - - def get_column_default_string(self, column): - if isinstance(column.default, schema.PassiveDefault): - if isinstance(column.default.arg, basestring): - return "'%s'" % column.default.arg - else: - return str(self._compile(column.default.arg, None)) - else: - return None - - def _compile(self, tocompile, parameters): - """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.dialect.compiler(tocompile, parameters) - compiler.compile() - return compiler - - def visit_check_constraint(self, constraint): - self.append(", \n\t") - if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" CHECK (%s)" % constraint.sqltext) - - def visit_column_check_constraint(self, constraint): - self.append(" CHECK (%s)" % constraint.sqltext) - - def visit_primary_key_constraint(self, constraint): - if len(constraint) == 0: - return - self.append(", \n\t") - if constraint.name is not None: - self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint])) - - def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and self.dialect.supports_alter(): - return - self.append(", \n\t ") - self.define_foreign_key(constraint) - - def add_foreignkey(self, constraint): - self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) - self.define_foreign_key(constraint) - self.execute() - - def define_foreign_key(self, constraint): - preparer = self.preparer - if constraint.name is not None: - self.append("CONSTRAINT %s " % - preparer.format_constraint(constraint)) - self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join([preparer.format_column(f.parent) for f in constraint.elements]), - preparer.format_table(list(constraint.elements)[0].column.table), - ', '.join([preparer.format_column(f.column) for f in constraint.elements]) - )) - if constraint.ondelete is not None: - self.append(" ON DELETE %s" % constraint.ondelete) - if constraint.onupdate is not None: - self.append(" ON UPDATE %s" % constraint.onupdate) - - def visit_unique_constraint(self, constraint): - self.append(", \n\t") - if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', 
'.join([self.preparer.format_column(c) for c in constraint]))) - - def visit_column(self, column): - pass - - def visit_index(self, index): - preparer = self.preparer - self.append("CREATE ") - if index.unique: - self.append("UNIQUE ") - self.append("INDEX %s ON %s (%s)" \ - % (preparer.format_index(index), - preparer.format_table(index.table), - string.join([preparer.format_column(c) for c in index.columns], ', '))) - self.execute() - -class ANSISchemaDropper(ANSISchemaBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables - self.preparer = dialect.preparer() - self.dialect = dialect - - def visit_metadata(self, metadata): - collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))] - if self.dialect.supports_alter(): - for alterable in self.find_alterables(collection): - self.drop_foreignkey(alterable) - for table in collection: - self.traverse_single(table) - - def visit_index(self, index): - self.append("\nDROP INDEX " + self.preparer.format_index(index)) - self.execute() - - def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( - self.preparer.format_table(constraint.table), - self.preparer.format_constraint(constraint))) - self.execute() - - def visit_table(self, table): - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) - - self.append("\nDROP TABLE " + self.preparer.format_table(table)) - self.execute() - -class ANSIDefaultRunner(engine.DefaultRunner): - pass - -class ANSIIdentifierPreparer(object): - """Handle quoting and case-folding of identifiers based on options.""" - - def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): - """Construct a new ``ANSIIdentifierPreparer`` object. - - initial_quote - Character that begins a delimited identifier. - - final_quote - Character that ends a delimited identifier. Defaults to `initial_quote`. - - omit_schema - Prevent prepending schema name. Useful for databases that do - not support schemae. - """ - - self.dialect = dialect - self.initial_quote = initial_quote - self.final_quote = final_quote or self.initial_quote - self.omit_schema = omit_schema - self.__strings = {} - - def _escape_identifier(self, value): - """Escape an identifier. - - Subclasses should override this to provide database-dependent - escaping behavior. - """ - - return value.replace('"', '""') - - def quote_identifier(self, value): - """Quote an identifier. - - Subclasses should override this to provide database-dependent - quoting behavior. - """ - - return self.initial_quote + self._escape_identifier(value) + self.final_quote - - def _fold_identifier_case(self, value): - """Fold the case of an identifier. - - Subclasses should override this to provide database-dependent - case folding behavior. - """ - - return value - # ANSI SQL calls for the case of all unquoted identifiers to be folded to UPPER. - # some tests would need to be rewritten if this is done. 
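To make the quoting rules above concrete, here is a standalone sketch of when the (now removed) ANSIIdentifierPreparer would quote an identifier and how it escapes embedded quotes. The reserved-word list is a toy subset, and the case-sensitivity check is simplified relative to the real case_sensitive flag handling.

    RESERVED = set(['select', 'from', 'order', 'user'])
    LEGAL = set('abcdefghijklmnopqrstuvwxyz'
                'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_$')

    def quote(ident):
        # wrap in double quotes, doubling any embedded quote characters
        return '"' + ident.replace('"', '""') + '"'

    def requires_quotes(ident):
        return (ident in RESERVED
                or ident[0].isdigit() or ident[0] == '$'
                or not all(c in LEGAL for c in ident)
                or ident.lower() != ident)

    for name in ['user_id', 'order', 'MixedCase', '2nd_column']:
        print name, '->', requires_quotes(name) and quote(name) or name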
- #return value.upper() - - def _reserved_words(self): - return RESERVED_WORDS - - def _legal_characters(self): - return LEGAL_CHARACTERS - - def _illegal_initial_characters(self): - return ILLEGAL_INITIAL_CHARACTERS - - def _requires_quotes(self, value, case_sensitive): - """Return True if the given identifier requires quoting.""" - return \ - value in self._reserved_words() \ - or (value[0] in self._illegal_initial_characters()) \ - or bool(len([x for x in unicode(value) if x not in self._legal_characters()])) \ - or (case_sensitive and value.lower() != value) - - def __generic_obj_format(self, obj, ident): - if getattr(obj, 'quote', False): - return self.quote_identifier(ident) - if self.dialect.cache_identifiers: - case_sens = getattr(obj, 'case_sensitive', None) - try: - return self.__strings[(ident, case_sens)] - except KeyError: - if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): - self.__strings[(ident, case_sens)] = self.quote_identifier(ident) - else: - self.__strings[(ident, case_sens)] = ident - return self.__strings[(ident, case_sens)] - else: - if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): - return self.quote_identifier(ident) - else: - return ident - - def should_quote(self, object): - return object.quote or self._requires_quotes(object.name, object.case_sensitive) - - def format_sequence(self, sequence): - return self.__generic_obj_format(sequence, sequence.name) - - def format_label(self, label, name=None): - return self.__generic_obj_format(label, name or label.name) - - def format_alias(self, alias, name=None): - return self.__generic_obj_format(alias, name or alias.name) - - def format_savepoint(self, savepoint): - return self.__generic_obj_format(savepoint, savepoint) - - def format_constraint(self, constraint): - return self.__generic_obj_format(constraint, constraint.name) - - def format_index(self, index): - return self.__generic_obj_format(index, index.name) - - def format_table(self, table, use_schema=True, name=None): - """Prepare a quoted table and schema name.""" - - if name is None: - name = table.name - result = self.__generic_obj_format(table, name) - if use_schema and getattr(table, "schema", None): - result = self.__generic_obj_format(table, table.schema) + "." + result - return result - - def format_column(self, column, use_table=False, name=None, table_name=None): - """Prepare a quoted column name.""" - if name is None: - name = column.name - if not getattr(column, 'is_literal', False): - if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name) - else: - return self.__generic_obj_format(column, name) - else: - # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted - if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." 
+ name - else: - return name - - def format_column_with_table(self, column, column_name=None, table_name=None): - """Prepare a quoted column name with table name.""" - - return self.format_column(column, use_table=True, name=column_name, table_name=table_name) - -dialect = ANSIDialect diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index b630473778..7bb8356f2e 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -1,8 +1,11 @@ # __init__.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -__all__ = ['sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird'] +__all__ = [ + 'sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird', + 'sybase', 'access', 'maxdb' + ] diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py new file mode 100644 index 0000000000..38dba17a5a --- /dev/null +++ b/lib/sqlalchemy/databases/access.py @@ -0,0 +1,430 @@ +# access.py +# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk +# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy import sql, schema, types, exceptions, pool +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base + + +class AcNumeric(types.Numeric): + def result_processor(self, dialect): + return None + + def bind_processor(self, dialect): + def process(value): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + return process + + def get_col_spec(self): + return "NUMERIC" + +class AcFloat(types.Float): + def get_col_spec(self): + return "FLOAT" + + def bind_processor(self, dialect): + """By converting to string, we can use Decimal types round-trip.""" + def process(value): + if not value is None: + return str(value) + return None + return process + +class AcInteger(types.Integer): + def get_col_spec(self): + return "INTEGER" + +class AcTinyInteger(types.Integer): + def get_col_spec(self): + return "TINYINT" + +class AcSmallInteger(types.Smallinteger): + def get_col_spec(self): + return "SMALLINT" + +class AcDateTime(types.DateTime): + def __init__(self, *a, **kw): + super(AcDateTime, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + +class AcDate(types.Date): + def __init__(self, *a, **kw): + super(AcDate, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + +class AcText(types.Text): + def get_col_spec(self): + return "MEMO" + +class AcString(types.String): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + +class AcUnicode(types.Unicode): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + return None + +class AcChar(types.CHAR): + def get_col_spec(self): + return "TEXT" + (self.length and ("(%d)" % self.length) or "") + +class AcBinary(types.Binary): + def get_col_spec(self): + return "BINARY" + +class AcBoolean(types.Boolean): + def get_col_spec(self): + return "YESNO" + + def result_processor(self, dialect): + def 
process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class AcTimeStamp(types.TIMESTAMP): + def get_col_spec(self): + return "TIMESTAMP" + +def descriptor(): + return {'name':'access', + 'description':'Microsoft Access', + 'arguments':[ + ('user',"Database user name",None), + ('password',"Database password",None), + ('db',"Path to database file",None), + ]} + +class AccessExecutionContext(default.DefaultExecutionContext): + def _has_implicit_sequence(self, column): + if column.primary_key and column.autoincrement: + if isinstance(column.type, types.Integer) and not column.foreign_key: + if column.default is None or (isinstance(column.default, schema.Sequence) and \ + column.default.optional): + return True + return False + + def post_exec(self): + """If we inserted into a row with a COUNTER column, fetch the ID""" + + if self.compiled.isinsert: + tbl = self.compiled.statement.table + if not hasattr(tbl, 'has_sequence'): + tbl.has_sequence = None + for column in tbl.c: + if getattr(column, 'sequence', False) or self._has_implicit_sequence(column): + tbl.has_sequence = column + break + + if bool(tbl.has_sequence): + # TBD: for some reason _last_inserted_ids doesn't exist here + # (but it does at corresponding point in mssql???) + #if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:] + # print "LAST ROW ID", self._last_inserted_ids + + super(AccessExecutionContext, self).post_exec() + + +const, daoEngine = None, None +class AccessDialect(default.DefaultDialect): + colspecs = { + types.Unicode : AcUnicode, + types.Integer : AcInteger, + types.Smallinteger: AcSmallInteger, + types.Numeric : AcNumeric, + types.Float : AcFloat, + types.DateTime : AcDateTime, + types.Date : AcDate, + types.String : AcString, + types.Binary : AcBinary, + types.Boolean : AcBoolean, + types.Text : AcText, + types.CHAR: AcChar, + types.TIMESTAMP: AcTimeStamp, + } + + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + def type_descriptor(self, typeobj): + newobj = types.adapt_type(typeobj, self.colspecs) + return newobj + + def __init__(self, **params): + super(AccessDialect, self).__init__(**params) + self.text_as_varchar = False + self._dtbs = None + + def dbapi(cls): + import win32com.client, pythoncom + + global const, daoEngine + if const is None: + const = win32com.client.constants + for suffix in (".36", ".35", ".30"): + try: + daoEngine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix) + break + except pythoncom.com_error: + pass + else: + raise exceptions.InvalidRequestError("Can't find a DB engine. 
Check http://support.microsoft.com/kb/239114 for details.") + + import pyodbc as module + return module + dbapi = classmethod(dbapi) + + def create_connect_args(self, url): + opts = url.translate_connect_args() + connectors = ["Driver={Microsoft Access Driver (*.mdb)}"] + connectors.append("Dbq=%s" % opts["database"]) + user = opts.get("username", None) + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % opts.get("password", "")) + return [[";".join(connectors)], {}] + + def create_execution_context(self, *args, **kwargs): + return AccessExecutionContext(self, *args, **kwargs) + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def do_execute(self, cursor, statement, params, **kwargs): + if params == {}: + params = () + super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs) + + def _execute(self, c, statement, parameters): + try: + if parameters == {}: + parameters = () + c.execute(statement, parameters) + self.context.rowcount = c.rowcount + except Exception, e: + raise exceptions.DBAPIError.instance(statement, parameters, e) + + def has_table(self, connection, tablename, schema=None): + # This approach seems to be more reliable that using DAO + try: + connection.execute('select top 1 * from [%s]' % tablename) + return True + except Exception, e: + return False + + def reflecttable(self, connection, table, include_columns): + # This is defined in the function, as it relies on win32com constants, + # that aren't imported until dbapi method is called + if not hasattr(self, 'ischema_names'): + self.ischema_names = { + const.dbByte: AcBinary, + const.dbInteger: AcInteger, + const.dbLong: AcInteger, + const.dbSingle: AcFloat, + const.dbDouble: AcFloat, + const.dbDate: AcDateTime, + const.dbLongBinary: AcBinary, + const.dbMemo: AcText, + const.dbBoolean: AcBoolean, + const.dbText: AcUnicode, # All Access strings are unicode + } + + # A fresh DAO connection is opened for each reflection + # This is necessary, so we get the latest updates + dtbs = daoEngine.OpenDatabase(connection.engine.url.database) + + try: + for tbl in dtbs.TableDefs: + if tbl.Name.lower() == table.name.lower(): + break + else: + raise exceptions.NoSuchTableError(table.name) + + for col in tbl.Fields: + coltype = self.ischema_names[col.Type] + if col.Type == const.dbText: + coltype = coltype(col.Size) + + colargs = \ + { + 'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField), + } + default = col.DefaultValue + + if col.Attributes & const.dbAutoIncrField: + colargs['default'] = schema.Sequence(col.Name + '_seq') + elif default: + if col.Type == const.dbBoolean: + default = default == 'Yes' and '1' or '0' + colargs['default'] = schema.PassiveDefault(sql.text(default)) + + table.append_column(schema.Column(col.Name, coltype, **colargs)) + + # TBD: check constraints + + # Find primary key columns first + for idx in tbl.Indexes: + if idx.Primary: + for col in idx.Fields: + thecol = table.c[col.Name] + table.primary_key.add(thecol) + if isinstance(thecol.type, AcInteger) and \ + not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)): + thecol.autoincrement = False + + # Then add other indexes + for idx in tbl.Indexes: + if not idx.Primary: + if len(idx.Fields) == 1: + col = table.c[idx.Fields[0].Name] + if not col.primary_key: + col.index = True + col.unique = idx.Unique + else: + pass # TBD: multi-column indexes + + + for fk in dtbs.Relations: + if fk.ForeignTable != table.name: + continue + scols = [c.ForeignName for c in 
fk.Fields] + rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields] + table.append_constraint(schema.ForeignKeyConstraint(scols, rcols)) + + finally: + dtbs.Close() + + def table_names(self, connection, schema): + # A fresh DAO connection is opened for each reflection + # This is necessary, so we get the latest updates + dtbs = daoEngine.OpenDatabase(connection.engine.url.database) + + names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] <> "~TMP"] + dtbs.Close() + return names + + +class AccessCompiler(compiler.DefaultCompiler): + def visit_select_precolumns(self, select): + """Access puts TOP, it's version of LIMIT here """ + s = select.distinct and "DISTINCT " or "" + if select.limit: + s += "TOP %s " % (select.limit) + if select.offset: + raise exceptions.InvalidRequestError('Access does not support LIMIT with an offset') + return s + + def limit_clause(self, select): + """Limit in access is after the select keyword""" + return "" + + def binary_operator_string(self, binary): + """Access uses "mod" instead of "%" """ + return binary.operator == '%' and 'mod' or binary.operator + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression._Function): + return column.label() + else: + return super(AccessCompiler, self).label_select_column(select, column, asfrom) + + function_rewrites = {'current_date': 'now', + 'current_timestamp': 'now', + 'length': 'len', + } + def visit_function(self, func): + """Access function names differ from the ANSI SQL names; rewrite common ones""" + func.name = self.function_rewrites.get(func.name, func.name) + super(AccessCompiler, self).visit_function(func) + + def for_update_clause(self, select): + """FOR UPDATE is not supported by Access; silently ignore""" + return '' + + # Strip schema + def visit_table(self, table, asfrom=False, **kwargs): + if asfrom: + return self.preparer.quote(table, table.name) + else: + return "" + + +class AccessSchemaGenerator(compiler.SchemaGenerator): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + + # install a sequence if we have an implicit IDENTITY column + if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ + column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_key: + if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + column.sequence = schema.Sequence(column.name + '_seq') + + if not column.nullable: + colspec += " NOT NULL" + + if hasattr(column, 'sequence'): + column.table.has_sequence = column + colspec = self.preparer.format_column(column) + " counter" + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + +class AccessSchemaDropper(compiler.SchemaDropper): + def visit_index(self, index): + self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name)) + self.execute() + +class AccessDefaultRunner(base.DefaultRunner): + pass + +class AccessIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = compiler.RESERVED_WORDS.copy() + reserved_words.update(['value', 'text']) + def __init__(self, dialect): + super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + + +dialect = AccessDialect +dialect.poolclass = pool.SingletonThreadPool +dialect.statement_compiler = 
AccessCompiler +dialect.schemagenerator = AccessSchemaGenerator +dialect.schemadropper = AccessSchemaDropper +dialect.preparer = AccessIdentifierPreparer +dialect.defaultrunner = AccessDefaultRunner diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 07f07644f2..5e1dd72bb0 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -1,21 +1,104 @@ # firebird.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +""" +Firebird backend +================ -import warnings +This module implements the Firebird backend, thru the kinterbasdb_ +DBAPI module. -from sqlalchemy import util, sql, schema, ansisql, exceptions -import sqlalchemy.engine.default as default -import sqlalchemy.types as sqltypes +Firebird dialects +----------------- + +Firebird offers two distinct dialects_ (not to be confused with the +SA ``Dialect`` thing): + +dialect 1 + This is the old syntax and behaviour, inherited from Interbase pre-6.0. + +dialect 3 + This is the newer and supported syntax, introduced in Interbase 6.0. + +From the user point of view, the biggest change is in date/time +handling: under dialect 1, there's a single kind of field, ``DATE`` +with a synonim ``DATETIME``, that holds a `timestamp` value, that is a +date with hour, minute, second. Under dialect 3 there are three kinds, +a ``DATE`` that holds a date, a ``TIME`` that holds a *time of the +day* value and a ``TIMESTAMP``, equivalent to the old ``DATE``. + +The problem is that the dialect of a Firebird database is a property +of the database itself [#]_ (that is, any single database has been +created with one dialect or the other: there is no way to change the +after creation). SQLAlchemy has a single instance of the class that +controls all the connections to a particular kind of database, so it +cannot easily differentiate between the two modes, and in particular +it **cannot** simultaneously talk with two distinct Firebird databases +with different dialects. + +By default this module is biased toward dialect 3, but you can easily +tweak it to handle dialect 1 if needed:: + + from sqlalchemy import types as sqltypes + from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names + + # Adjust the mapping of the timestamp kind + ischema_names['TIMESTAMP'] = FBDate + colspecs[sqltypes.DateTime] = FBDate, + +Other aspects may be version-specific. You can use the ``server_version_info()`` method +on the ``FBDialect`` class to do whatever is needed:: + + from sqlalchemy.databases.firebird import FBCompiler + + if engine.dialect.server_version_info(connection) < (2,0): + # Change the name of the function ``length`` to use the UDF version + # instead of ``char_length`` + FBCompiler.LENGTH_FUNCTION_NAME = 'strlen' + +Pooling connections +------------------- + +The default strategy used by SQLAlchemy to pool the database connections +in particular cases may raise an ``OperationalError`` with a message +`"object XYZ is in use"`. This happens on Firebird when there are two +connections to the database, one is using, or has used, a particular table +and the other tries to drop or alter the same table. To garantee DDL +operations success Firebird recommend doing them as the single connected user. 
+ +In case your SA application effectively needs to do DDL operations while other +connections are active, the following setting may alleviate the problem:: + + from sqlalchemy import pool + from sqlalchemy.databases.firebird import dialect + + # Force SA to use a single connection per thread + dialect.poolclass = pool.SingletonThreadPool + + +.. [#] Well, that is not the whole story, as the client may still ask + a different (lower) dialect... + +.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html +.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb +""" + + +import datetime + +from sqlalchemy import exceptions, schema, types as sqltypes, sql, util +from sqlalchemy.engine import base, default _initialized_kb = False class FBNumeric(sqltypes.Numeric): + """Handle ``NUMERIC(precision,length)`` datatype.""" + def get_col_spec(self): if self.precision is None: return "NUMERIC" @@ -23,47 +106,107 @@ class FBNumeric(sqltypes.Numeric): return "NUMERIC(%(precision)s, %(length)s)" % { 'precision': self.precision, 'length' : self.length } + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process + + +class FBFloat(sqltypes.Float): + """Handle ``FLOAT(precision)`` datatype.""" + + def get_col_spec(self): + if not self.precision: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': self.precision} + + class FBInteger(sqltypes.Integer): + """Handle ``INTEGER`` datatype.""" + def get_col_spec(self): return "INTEGER" class FBSmallInteger(sqltypes.Smallinteger): + """Handle ``SMALLINT`` datatype.""" + def get_col_spec(self): return "SMALLINT" class FBDateTime(sqltypes.DateTime): + """Handle ``TIMESTAMP`` datatype.""" + def get_col_spec(self): return "TIMESTAMP" + def bind_processor(self, dialect): + def process(value): + if value is None or isinstance(value, datetime.datetime): + return value + else: + return datetime.datetime(year=value.year, + month=value.month, + day=value.day) + return process + class FBDate(sqltypes.DateTime): + """Handle ``DATE`` datatype.""" + def get_col_spec(self): return "DATE" -class FBText(sqltypes.TEXT): +class FBTime(sqltypes.Time): + """Handle ``TIME`` datatype.""" + def get_col_spec(self): - return "BLOB SUB_TYPE 2" + return "TIME" + + +class FBText(sqltypes.Text): + """Handle ``BLOB SUB_TYPE 1`` datatype (aka *textual* blob).""" + + def get_col_spec(self): + return "BLOB SUB_TYPE 1" class FBString(sqltypes.String): + """Handle ``VARCHAR(length)`` datatype.""" + def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} class FBChar(sqltypes.CHAR): + """Handle ``CHAR(length)`` datatype.""" + def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} class FBBinary(sqltypes.Binary): + """Handle ``BLOB SUB_TYPE 0`` datatype (aka *binary* blob).""" + def get_col_spec(self): - return "BLOB SUB_TYPE 1" + return "BLOB SUB_TYPE 0" class FBBoolean(sqltypes.Boolean): + """Handle boolean values as a ``SMALLINT`` datatype.""" + def get_col_spec(self): return "SMALLINT" @@ -72,17 +215,35 @@ colspecs = { sqltypes.Integer : FBInteger, sqltypes.Smallinteger : FBSmallInteger, sqltypes.Numeric : FBNumeric, - sqltypes.Float : FBNumeric, + sqltypes.Float : FBFloat, sqltypes.DateTime : FBDateTime, sqltypes.Date : FBDate, + sqltypes.Time : FBTime, sqltypes.String : FBString, sqltypes.Binary : FBBinary, 
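+    # Illustrative note (not load-bearing): in this SQLAlchemy version the
+    # ``length`` argument above plays the role of the SQL scale, so e.g. a
+    # Numeric(precision=10, length=2) column is expected to render as
+    # "NUMERIC(10, 2)".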
sqltypes.Boolean : FBBoolean, - sqltypes.TEXT : FBText, + sqltypes.Text : FBText, sqltypes.CHAR: FBChar, } +ischema_names = { + 'SHORT': lambda r: FBSmallInteger(), + 'LONG': lambda r: FBInteger(), + 'QUAD': lambda r: FBFloat(), + 'FLOAT': lambda r: FBFloat(), + 'DATE': lambda r: FBDate(), + 'TIME': lambda r: FBTime(), + 'TEXT': lambda r: FBString(r['flen']), + 'INT64': lambda r: FBNumeric(precision=r['fprec'], length=r['fscale'] * -1), # This generically handles NUMERIC() + 'DOUBLE': lambda r: FBFloat(), + 'TIMESTAMP': lambda r: FBDateTime(), + 'VARYING': lambda r: FBString(r['flen']), + 'CSTRING': lambda r: FBChar(r['flen']), + 'BLOB': lambda r: r['stype']==1 and FBText() or FBBinary() + } + + def descriptor(): return {'name':'firebird', 'description':'Firebird', @@ -95,13 +256,20 @@ def descriptor(): class FBExecutionContext(default.DefaultExecutionContext): - def supports_sane_rowcount(self): - return True + pass -class FBDialect(ansisql.ANSIDialect): +class FBDialect(default.DefaultDialect): + """Firebird dialect""" + + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + max_identifier_length = 31 + preexecute_pk_sequences = True + supports_pk_autoincrement = False + def __init__(self, type_conv=200, concurrency_level=1, **kwargs): - ansisql.ANSIDialect.__init__(self, **kwargs) + default.DefaultDialect.__init__(self, **kwargs) self.type_conv = type_conv self.concurrency_level= concurrency_level @@ -110,9 +278,9 @@ class FBDialect(ansisql.ANSIDialect): import kinterbasdb return kinterbasdb dbapi = classmethod(dbapi) - + def create_connect_args(self, url): - opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) + opts = url.translate_connect_args(username='user') if opts.get('port'): opts['host'] = "%s/%s" % (opts['host'], opts['port']) del opts['port'] @@ -132,36 +300,90 @@ class FBDialect(ansisql.ANSIDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def supports_sane_rowcount(self): - return False + def server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. - def compiler(self, statement, bindparams, **kwargs): - return FBCompiler(self, statement, bindparams, **kwargs) + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ - def schemagenerator(self, *args, **kwargs): - return FBSchemaGenerator(self, *args, **kwargs) + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. This is more than enough for our purposes, + # as this is mainly (only?) used by the testsuite. 
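+        # As a rough illustration: for the sample string above, the regex
+        # below captures ('2', '0', '12981') via m.group(5, 6, 4), so the
+        # method would return (2, 0, 12981), i.e. (major, minor, build).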
+ + from re import match + + fbconn = connection.connection.connection + version = fbconn.server_version + m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) + if not m: + raise exceptions.AssertionError("Could not determine version from string '%s'" % version) + return tuple([int(x) for x in m.group(5, 6, 4)]) + + def _normalize_name(self, name): + """Convert the name to lowercase if it is possible""" + + # Remove trailing spaces: FB uses a CHAR() type, + # that is padded with spaces + name = name and name.rstrip() + if name is None: + return None + elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower()): + return name.lower() + else: + return name - def schemadropper(self, *args, **kwargs): - return FBSchemaDropper(self, *args, **kwargs) + def _denormalize_name(self, name): + """Revert a *normalized* name to its uppercase equivalent""" - def defaultrunner(self, connection): - return FBDefaultRunner(connection) + if name is None: + return None + elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): + return name.upper() + else: + return name - def preparer(self): - return FBIdentifierPreparer(self) + def table_names(self, connection, schema): + """Return a list of *normalized* table names omitting system relations.""" - def max_identifier_length(self): - return 31 + s = """ + SELECT r.rdb$relation_name + FROM rdb$relations r + WHERE r.rdb$system_flag=0 + """ + return [self._normalize_name(row[0]) for row in connection.execute(s)] def has_table(self, connection, table_name, schema=None): + """Return ``True`` if the given table exists, ignoring the `schema`.""" + tblqry = """ - SELECT count(*) - FROM RDB$RELATIONS R - WHERE R.RDB$RELATION_NAME=?""" + SELECT 1 FROM rdb$database + WHERE EXISTS (SELECT rdb$relation_name + FROM rdb$relations + WHERE rdb$relation_name=?) + """ + c = connection.execute(tblqry, [self._denormalize_name(table_name)]) + row = c.fetchone() + if row is not None: + return True + else: + return False - c = connection.execute(tblqry, [table_name.upper()]) + def has_sequence(self, connection, sequence_name): + """Return ``True`` if the given sequence (generator) exists.""" + + genqry = """ + SELECT 1 FROM rdb$database + WHERE EXISTS (SELECT rdb$generator_name + FROM rdb$generators + WHERE rdb$generator_name=?) 
+ """ + c = connection.execute(genqry, [self._denormalize_name(sequence_name)]) row = c.fetchone() - if row[0] > 0: + if row is not None: return True else: return False @@ -169,97 +391,93 @@ class FBDialect(ansisql.ANSIDialect): def is_disconnect(self, e): if isinstance(e, self.dbapi.OperationalError): return 'Unable to complete network request to host' in str(e) + elif isinstance(e, self.dbapi.ProgrammingError): + return 'Invalid connection state' in str(e) else: return False def reflecttable(self, connection, table, include_columns): - #TODO: map these better - column_func = { - 14 : lambda r: sqltypes.String(r['FLEN']), # TEXT - 7 : lambda r: sqltypes.Integer(), # SHORT - 8 : lambda r: sqltypes.Integer(), # LONG - 9 : lambda r: sqltypes.Float(), # QUAD - 10 : lambda r: sqltypes.Float(), # FLOAT - 27 : lambda r: sqltypes.Float(), # DOUBLE - 35 : lambda r: sqltypes.DateTime(), # TIMESTAMP - 37 : lambda r: sqltypes.String(r['FLEN']), # VARYING - 261: lambda r: sqltypes.TEXT(), # BLOB - 40 : lambda r: sqltypes.Char(r['FLEN']), # CSTRING - 12 : lambda r: sqltypes.Date(), # DATE - 13 : lambda r: sqltypes.Time(), # TIME - 16 : lambda r: sqltypes.Numeric(precision=r['FPREC'], length=r['FSCALE'] * -1) #INT64 - } + # Query to extract the details of all the fields of the given table tblqry = """ - SELECT DISTINCT R.RDB$FIELD_NAME AS FNAME, - R.RDB$NULL_FLAG AS NULL_FLAG, - R.RDB$FIELD_POSITION, - F.RDB$FIELD_TYPE AS FTYPE, - F.RDB$FIELD_SUB_TYPE AS STYPE, - F.RDB$FIELD_LENGTH AS FLEN, - F.RDB$FIELD_PRECISION AS FPREC, - F.RDB$FIELD_SCALE AS FSCALE - FROM RDB$RELATION_FIELDS R - JOIN RDB$FIELDS F ON R.RDB$FIELD_SOURCE=F.RDB$FIELD_NAME - WHERE F.RDB$SYSTEM_FLAG=0 and R.RDB$RELATION_NAME=? - ORDER BY R.RDB$FIELD_POSITION""" + SELECT DISTINCT r.rdb$field_name AS fname, + r.rdb$null_flag AS null_flag, + t.rdb$type_name AS ftype, + f.rdb$field_sub_type AS stype, + f.rdb$field_length AS flen, + f.rdb$field_precision AS fprec, + f.rdb$field_scale AS fscale, + COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault + FROM rdb$relation_fields r + JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name + JOIN rdb$types t ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE' + WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? + ORDER BY r.rdb$field_position + """ + # Query to extract the PK/FK constrained fields of the given table keyqry = """ - SELECT SE.RDB$FIELD_NAME SENAME - FROM RDB$RELATION_CONSTRAINTS RC - JOIN RDB$INDEX_SEGMENTS SE - ON RC.RDB$INDEX_NAME=SE.RDB$INDEX_NAME - WHERE RC.RDB$CONSTRAINT_TYPE=? AND RC.RDB$RELATION_NAME=?""" + SELECT se.rdb$field_name AS fname + FROM rdb$relation_constraints rc + JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + """ + # Query to extract the details of each UK/FK of the given table fkqry = """ - SELECT RC.RDB$CONSTRAINT_NAME CNAME, - CSE.RDB$FIELD_NAME FNAME, - IX2.RDB$RELATION_NAME RNAME, - SE.RDB$FIELD_NAME SENAME - FROM RDB$RELATION_CONSTRAINTS RC - JOIN RDB$INDICES IX1 - ON IX1.RDB$INDEX_NAME=RC.RDB$INDEX_NAME - JOIN RDB$INDICES IX2 - ON IX2.RDB$INDEX_NAME=IX1.RDB$FOREIGN_KEY - JOIN RDB$INDEX_SEGMENTS CSE - ON CSE.RDB$INDEX_NAME=IX1.RDB$INDEX_NAME - JOIN RDB$INDEX_SEGMENTS SE - ON SE.RDB$INDEX_NAME=IX2.RDB$INDEX_NAME AND SE.RDB$FIELD_POSITION=CSE.RDB$FIELD_POSITION - WHERE RC.RDB$CONSTRAINT_TYPE=? AND RC.RDB$RELATION_NAME=? 
- ORDER BY SE.RDB$INDEX_NAME, SE.RDB$FIELD_POSITION""" + SELECT rc.rdb$constraint_name AS cname, + cse.rdb$field_name AS fname, + ix2.rdb$relation_name AS targetrname, + se.rdb$field_name AS targetfname + FROM rdb$relation_constraints rc + JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name + JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key + JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name + JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND se.rdb$field_position=cse.rdb$field_position + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + ORDER BY se.rdb$index_name, se.rdb$field_position + """ + # Heuristic-query to determine the generator associated to a PK field + genqry = """ + SELECT trigdep.rdb$depended_on_name AS fgenerator + FROM rdb$dependencies tabdep + JOIN rdb$dependencies trigdep ON (tabdep.rdb$dependent_name=trigdep.rdb$dependent_name + AND trigdep.rdb$depended_on_type=14 + AND trigdep.rdb$dependent_type=2) + JOIN rdb$triggers trig ON (trig.rdb$trigger_name=tabdep.rdb$dependent_name) + WHERE tabdep.rdb$depended_on_name=? + AND tabdep.rdb$depended_on_type=0 + AND trig.rdb$trigger_type=1 + AND tabdep.rdb$field_name=? + AND (SELECT count(*) + FROM rdb$dependencies trigdep2 + WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2 + """ + + tablename = self._denormalize_name(table.name) # get primary key fields - c = connection.execute(keyqry, ["PRIMARY KEY", table.name.upper()]) - pkfields =[r['SENAME'] for r in c.fetchall()] + c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) + pkfields =[self._normalize_name(r['fname']) for r in c.fetchall()] # get all of the fields for this table + c = connection.execute(tblqry, [tablename]) - def lower_if_possible(name): - # Remove trailing spaces: FB uses a CHAR() type, - # that is padded with spaces - name = name.rstrip() - # If its composed only by upper case chars, use - # the lowered version, otherwise keep the original - # (even if stripped...) - lname = name.lower() - if lname.upper() == name and not ' ' in name: - return lname - return name - - c = connection.execute(tblqry, [table.name.upper()]) - row = c.fetchone() - if not row: - raise exceptions.NoSuchTableError(table.name) + found_table = False + while True: + row = c.fetchone() + if row is None: + break + found_table = True - while row: - name = row['FNAME'] - python_name = lower_if_possible(name) - if include_columns and python_name not in include_columns: + name = self._normalize_name(row['fname']) + if include_columns and name not in include_columns: continue - args = [python_name] + args = [name] kw = {} - # get the data types and lengths - coltype = column_func.get(row['FTYPE'], None) + # get the data type + coltype = ischema_names.get(row['ftype'].rstrip()) if coltype is None: - warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (str(row['FTYPE']), name))) + util.warn("Did not recognize type '%s' of column '%s'" % + (str(row['ftype']), name)) coltype = sqltypes.NULLTYPE else: coltype = coltype(row) @@ -268,25 +486,47 @@ class FBDialect(ansisql.ANSIDialect): # is it a primary key? kw['primary_key'] = name in pkfields - table.append_column(schema.Column(*args, **kw)) - row = c.fetchone() + # is it nullable? + kw['nullable'] = not bool(row['null_flag']) + + # does it have a default value? 
+ if row['fdefault'] is not None: + # the value comes down as "DEFAULT 'value'" + assert row['fdefault'].startswith('DEFAULT ') + defvalue = row['fdefault'][8:] + args.append(schema.PassiveDefault(sql.text(defvalue))) + + col = schema.Column(*args, **kw) + if kw['primary_key']: + # if the PK is a single field, try to see if its linked to + # a sequence thru a trigger + if len(pkfields)==1: + genc = connection.execute(genqry, [tablename, row['fname']]) + genr = genc.fetchone() + if genr is not None: + col.sequence = schema.Sequence(self._normalize_name(genr['fgenerator'])) + + table.append_column(col) + + if not found_table: + raise exceptions.NoSuchTableError(table.name) # get the foreign keys - c = connection.execute(fkqry, ["FOREIGN KEY", table.name.upper()]) + c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) fks = {} while True: row = c.fetchone() if not row: break - cname = lower_if_possible(row['CNAME']) + cname = self._normalize_name(row['cname']) try: fk = fks[cname] except KeyError: fks[cname] = fk = ([], []) - rname = lower_if_possible(row['RNAME']) + rname = self._normalize_name(row['targetrname']) schema.Table(rname, table.metadata, autoload=True, autoload_with=connection) - fname = lower_if_possible(row['FNAME']) - refspec = rname + '.' + lower_if_possible(row['SENAME']) + fname = self._normalize_name(row['fname']) + refspec = rname + '.' + self._normalize_name(row['targetfname']) fk[0].append(fname) fk[1].append(refspec) @@ -294,33 +534,59 @@ class FBDialect(ansisql.ANSIDialect): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) def do_execute(self, cursor, statement, parameters, **kwargs): + # kinterbase does not accept a None, but wants an empty list + # when there are no arguments. cursor.execute(statement, parameters or []) def do_rollback(self, connection): + # Use the retaining feature, that keeps the transaction going connection.rollback(True) def do_commit(self, connection): + # Use the retaining feature, that keeps the transaction going connection.commit(True) -class FBCompiler(ansisql.ANSICompiler): +def _substring(s, start, length=None): + "Helper function to handle Firebird 2 SUBSTRING builtin" + + if length is None: + return "SUBSTRING(%s FROM %s)" % (s, start) + else: + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + + +class FBCompiler(sql.compiler.DefaultCompiler): """Firebird specific idiosincrasies""" + # Firebird lacks a builtin modulo operator, but there is + # an equivalent function in the ib_udf library. 
+ operators = sql.compiler.DefaultCompiler.operators.copy() + operators.update({ + sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y) + }) + def visit_alias(self, alias, asfrom=False, **kwargs): # Override to not use the AS keyword which FB 1.5 does not like if asfrom: - return self.process(alias.original, asfrom=True) + " " + self.preparer.format_alias(alias) + return self.process(alias.original, asfrom=True, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name)) else: - return self.process(alias.original, asfrom=True) + return self.process(alias.original, **kwargs) + + functions = sql.compiler.DefaultCompiler.functions.copy() + functions['substring'] = _substring - def visit_function(self, func): - if len(func.clauses): - return super(FBCompiler, self).visit_function(func) + def function_argspec(self, func): + if func.clauses: + return self.process(func.clause_expr) else: - return func.name + return "" + + def default_from(self): + return " FROM rdb$database" - def uses_sequences_for_inserts(self): - return True + def visit_sequence(self, seq): + return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just @@ -330,22 +596,37 @@ class FBCompiler(ansisql.ANSICompiler): result = "" if select._limit: - result += " FIRST %d " % select._limit + result += "FIRST %d " % select._limit if select._offset: - result +=" SKIP %d " % select._offset + result +="SKIP %d " % select._offset if select._distinct: - result += " DISTINCT " + result += "DISTINCT " return result def limit_clause(self, select): """Already taken care of in the `get_select_precolumns` method.""" + return "" + LENGTH_FUNCTION_NAME = 'char_length' + def function_string(self, func): + """Substitute the ``length`` function. + + On newer FB there is a ``char_length`` function, while older + ones need the ``strlen`` UDF. 
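+        For such servers the function name can be swapped at the class
+        level, e.g. ``FBCompiler.LENGTH_FUNCTION_NAME = 'strlen'`` (the
+        same tweak shown in the module docstring above).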
+ """ + + if func.name == 'length': + return self.LENGTH_FUNCTION_NAME + '%(expr)s' + return super(FBCompiler, self).function_string(func) + + +class FBSchemaGenerator(sql.compiler.SchemaGenerator): + """Firebird syntactic idiosincrasies""" -class FBSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() default = self.get_column_default_string(column) if default is not None: @@ -357,23 +638,30 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - self.append("CREATE GENERATOR %s" % sequence.name) + """Generate a ``CREATE GENERATOR`` statement for the sequence.""" + + self.append("CREATE GENERATOR %s" % self.preparer.format_sequence(sequence)) self.execute() -class FBSchemaDropper(ansisql.ANSISchemaDropper): +class FBSchemaDropper(sql.compiler.SchemaDropper): + """Firebird syntactic idiosincrasies""" + def visit_sequence(self, sequence): - self.append("DROP GENERATOR %s" % sequence.name) + """Generate a ``DROP GENERATOR`` statement for the sequence.""" + + self.append("DROP GENERATOR %s" % self.preparer.format_sequence(sequence)) self.execute() -class FBDefaultRunner(ansisql.ANSIDefaultRunner): - def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection) - return self.connection.execute_compiled(c).scalar() +class FBDefaultRunner(base.DefaultRunner): + """Firebird specific idiosincrasies""" def visit_sequence(self, seq): - return self.connection.execute_text("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").scalar() + """Get the next value from the sequence using ``gen_id()``.""" + + return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \ + self.dialect.identifier_preparer.format_sequence(seq)) RESERVED_WORDS = util.Set( @@ -417,12 +705,18 @@ RESERVED_WORDS = util.Set( "whenever", "where", "while", "with", "work", "write", "year", "yearday" ]) -class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): + """Install Firebird specific reserved words.""" + + reserved_words = RESERVED_WORDS + def __init__(self, dialect): super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True) - def _reserved_words(self): - return RESERVED_WORDS - dialect = FBDialect +dialect.statement_compiler = FBCompiler +dialect.schemagenerator = FBSchemaGenerator +dialect.schemadropper = FBSchemaDropper +dialect.defaultrunner = FBDefaultRunner +dialect.preparer = FBIdentifierPreparer diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 93f47de15e..1b3b3838ab 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -76,26 +76,11 @@ ref_constraints = Table("referential_constraints", ischema, Column("update_rule", String), Column("delete_rule", String), schema="information_schema") - -class ISchema(object): - def __init__(self, engine): - self.engine = engine - self.cache = {} - def __getattr__(self, name): - if name not in self.cache: - # This is a bit of a hack. 
- # It would probably be better to have a dict - # with just the information_schema tables at - # the module level, so as to avoid returning - # unrelated objects that happen to be named - # 'gen_*' - try: - gen_tbl = globals()['gen_'+name] - except KeyError: - raise exceptions.ArgumentError('information_schema table %s not found' % name) - self.cache[name] = gen_tbl.toengine(self.engine) - return self.cache[name] + +def table_names(connection, schema): + s = select([tables.c.table_name], tables.c.table_schema==schema) + return [row[0] for row in connection.execute(s)] def reflecttable(connection, table, include_columns, ischema_names): diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index f3a6cf60e7..2e1f19de96 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -1,15 +1,17 @@ -# coding: gbk # informix.py -# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005,2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com +# +# coding: gbk # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, warnings +import datetime -from sqlalchemy import sql, schema, ansisql, exceptions, pool -import sqlalchemy.engine.default as default -import sqlalchemy.types as sqltypes +from sqlalchemy import sql, schema, exceptions, pool, util +from sqlalchemy.sql import compiler +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes # for offset @@ -18,7 +20,7 @@ class informix_cursor(object): def __init__( self , con ): self.__cursor = con.cursor() self.rowcount = 0 - + def offset( self , n ): if n > 0: self.fetchmany( n ) @@ -27,13 +29,13 @@ class informix_cursor(object): self.rowcount = 0 else: self.rowcount = self.__cursor.rowcount - + def execute( self , sql , params ): if params is None or len( params ) == 0: params = [] - + return self.__cursor.execute( sql , params ) - + def __getattr__( self , name ): if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ): return getattr( self.__cursor , name ) @@ -44,7 +46,7 @@ class InfoNumeric(sqltypes.Numeric): return 'NUMERIC' else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} - + class InfoInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -60,29 +62,35 @@ class InfoDate(sqltypes.Date): class InfoDateTime(sqltypes.DateTime ): def get_col_spec(self): return "DATETIME YEAR TO SECOND" - - def convert_bind_param(self, value, dialect): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value + + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace( microsecond = 0 ) + return value + return process class InfoTime(sqltypes.Time ): def get_col_spec(self): return "DATETIME HOUR TO SECOND" - def convert_bind_param(self, value, dialect): - if value is not None: - if value.microsecond: - value = value.replace( microsecond = 0 ) - return value - - def convert_result_value(self, value, dialect): - if isinstance( value , datetime.datetime ): - return value.time() - else: - return value - + def bind_processor(self, dialect): + def process(value): + if value is not None: + if value.microsecond: + value = value.replace( microsecond = 0 ) + return value + return process + + def result_processor(self, dialect): + def process(value): + if 
isinstance( value , datetime.datetime ): + return value.time() + else: + return value + return process + class InfoText(sqltypes.String): def get_col_spec(self): return "VARCHAR(255)" @@ -90,38 +98,47 @@ class InfoText(sqltypes.String): class InfoString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} - - def convert_bind_param( self , value , dialect ): - if value == '': - return None - else: - return value + + def bind_processor(self, dialect): + def process(value): + if value == '': + return None + else: + return value + return process class InfoChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class InfoBinary(sqltypes.Binary): def get_col_spec(self): return "BYTE" + class InfoBoolean(sqltypes.Boolean): default_type = 'NUM' def get_col_spec(self): return "SMALLINT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process - colspecs = { sqltypes.Integer : InfoInteger, sqltypes.Smallinteger : InfoSmallInteger, @@ -133,32 +150,32 @@ colspecs = { sqltypes.String : InfoString, sqltypes.Binary : InfoBinary, sqltypes.Boolean : InfoBoolean, - sqltypes.TEXT : InfoText, + sqltypes.Text : InfoText, sqltypes.CHAR: InfoChar, } ischema_names = { - 0 : InfoString, # CHAR + 0 : InfoString, # CHAR 1 : InfoSmallInteger, # SMALLINT - 2 : InfoInteger, # INT + 2 : InfoInteger, # INT 3 : InfoNumeric, # Float 3 : InfoNumeric, # SmallFloat - 5 : InfoNumeric, # DECIMAL + 5 : InfoNumeric, # DECIMAL 6 : InfoInteger, # Serial - 7 : InfoDate, # DATE + 7 : InfoDate, # DATE 8 : InfoNumeric, # MONEY 10 : InfoDateTime, # DATETIME - 11 : InfoBinary, # BYTE - 12 : InfoText, # TEXT - 13 : InfoString, # VARCHAR - 15 : InfoString, # NCHAR - 16 : InfoString, # NVARCHAR + 11 : InfoBinary, # BYTE + 12 : InfoText, # TEXT + 13 : InfoString, # VARCHAR + 15 : InfoString, # NCHAR + 16 : InfoString, # NVARCHAR 17 : InfoInteger, # INT8 18 : InfoInteger, # Serial8 43 : InfoString, # LVARCHAR - -1 : InfoBinary, # BLOB - -1 : InfoText, # CLOB + -1 : InfoBinary, # BLOB + -1 : InfoText, # CLOB } def descriptor(): @@ -187,23 +204,21 @@ class InfoExecutionContext(default.DefaultExecutionContext): def create_cursor( self ): return informix_cursor( self.connection.connection ) - -class InfoDialect(ansisql.ANSIDialect): - + +class InfoDialect(default.DefaultDialect): + default_paramstyle = 'qmark' + # for informix 7.31 + max_identifier_length = 18 + def __init__(self, use_ansi=True,**kwargs): self.use_ansi = use_ansi - ansisql.ANSIDialect.__init__(self, **kwargs) - self.paramstyle = 'qmark' + default.DefaultDialect.__init__(self, **kwargs) def dbapi(cls): import informixdb return informixdb dbapi = classmethod(dbapi) - def max_identifier_length( self ): - # for informix 7.31 - return 18 - def is_disconnect(self, e): if isinstance(e, self.dbapi.OperationalError): return 'closed the connection' in str(e) or 'connection not open' in str(e) @@ -214,7 +229,7 @@ class 
InfoDialect(ansisql.ANSIDialect): cu = connect.cursor() cu.execute( 'SET LOCK MODE TO WAIT' ) #cu.execute( 'SET ISOLATION TO REPEATABLE READ' ) - + def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) @@ -223,36 +238,28 @@ class InfoDialect(ansisql.ANSIDialect): dsn = '%s@%s' % ( url.database , url.host ) else: dsn = url.database - + if url.username: opt = { 'user':url.username , 'password': url.password } else: opt = {} - + return ([dsn,], opt ) - + def create_execution_context(self , *args, **kwargs): return InfoExecutionContext(self, *args, **kwargs) - + def oid_column_name(self,column): return "rowid" - - def preparer(self): - return InfoIdentifierPreparer(self) - - def compiler(self, statement, bindparams, **kwargs): - return InfoCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return InfoSchemaGenerator( self , *args, **kwargs) - - def schemadropper(self, *args, **params): - return InfoSchemaDroper( self , *args , **params) - + + def table_names(self, connection, schema): + s = "select tabname from systables" + return [row[0] for row in connection.execute(s)] + def has_table(self, connection, table_name,schema=None): cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() ) return bool( cursor.fetchone() is not None ) - + def reflecttable(self, connection, table, include_columns): c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() ) rows = c.fetchall() @@ -271,12 +278,12 @@ class InfoDialect(ansisql.ANSIDialect): raise exceptions.AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name) c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3 - where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? + where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? 
and t3.tabid = t2.tabid and t3.colno = t1.colno order by t1.colno""", table.name.lower(), owner ) rows = c.fetchall() - - if not rows: + + if not rows: raise exceptions.NoSuchTableError(table.name) for name , colattr , collength , default , colno in rows: @@ -286,11 +293,11 @@ class InfoDialect(ansisql.ANSIDialect): # in 7.31, coltype = 0x000 # ^^-- column type - # ^-- 1 not null , 0 null + # ^-- 1 not null , 0 null nullable , coltype = divmod( colattr , 256 ) if coltype not in ( 0 , 13 ) and default: default = default.split()[-1] - + if coltype == 0 or coltype == 13: # char , varchar coltype = ischema_names.get(coltype, InfoString)(collength) if default: @@ -304,28 +311,29 @@ class InfoDialect(ansisql.ANSIDialect): try: coltype = ischema_names[coltype] except KeyError: - warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name))) + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, name)) coltype = sqltypes.NULLTYPE - + colargs = [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - + table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs)) # FK - c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , - t4.colname as local_column , t7.tabname as remote_table , - t6.colname as remote_column - from sysconstraints as t1 , systables as t2 , - sysindexes as t3 , syscolumns as t4 , - sysreferences as t5 , syscolumns as t6 , systables as t7 , + c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , + t4.colname as local_column , t7.tabname as remote_table , + t6.colname as remote_column + from sysconstraints as t1 , systables as t2 , + sysindexes as t3 , syscolumns as t4 , + sysreferences as t5 , syscolumns as t6 , systables as t7 , sysconstraints as t8 , sysindexes as t9 where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'R' and t3.tabid = t2.tabid and t3.idxname = t1.idxname and t4.tabid = t2.tabid and t4.colno = t3.part1 and t5.constrid = t1.constrid and t8.constrid = t5.primary - and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname + and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname and t7.tabid = t5.ptabid""", table.name.lower(), owner ) rows = c.fetchall() fks = {} @@ -341,15 +349,15 @@ class InfoDialect(ansisql.ANSIDialect): fk[0].append(local_column) if refspec not in fk[1]: fk[1].append(refspec) - + for name, value in fks.iteritems(): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1] , None )) - + # PK - c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , - t4.colname as local_column - from sysconstraints as t1 , systables as t2 , - sysindexes as t3 , syscolumns as t4 + c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , + t4.colname as local_column + from sysconstraints as t1 , systables as t2 , + sysindexes as t3 , syscolumns as t4 where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? 
and t1.constrtype = 'P' and t3.tabid = t2.tabid and t3.idxname = t1.idxname and t4.tabid = t2.tabid and t4.colno = t3.part1""", table.name.lower(), owner ) @@ -357,18 +365,19 @@ class InfoDialect(ansisql.ANSIDialect): for cons_name, cons_type, local_column in rows: table.primary_key.add( table.c[local_column] ) -class InfoCompiler(ansisql.ANSICompiler): - """Info compiler modifies the lexical structure of Select statements to work under +class InfoCompiler(compiler.DefaultCompiler): + """Info compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - def __init__(self, dialect, statement, parameters=None, **kwargs): + + def __init__(self, *args, **kwargs): self.limit = 0 self.offset = 0 - - ansisql.ANSICompiler.__init__( self , dialect , statement , parameters , **kwargs ) - + + compiler.DefaultCompiler.__init__( self , *args, **kwargs ) + def default_from(self): return " from systables where tabname = 'systables' " - + def get_select_precolumns( self , select ): s = select._distinct and "DISTINCT " or "" # only has limit @@ -378,38 +387,30 @@ class InfoCompiler(ansisql.ANSICompiler): else: s += "" return s - + def visit_select(self, select): if select._offset: self.offset = select._offset self.limit = select._limit or 0 # the column in order by clause must in select too - + def __label( c ): try: return c._label.lower() except: return '' - - # TODO: dont modify the original select, generate a new one + + # TODO: dont modify the original select, generate a new one a = [ __label(c) for c in select._raw_columns ] - for c in select.order_by_clause.clauses: + for c in select._order_by_clause.clauses: if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid': select.append_column( c ) - - return ansisql.ANSICompiler.visit_select(self, select) - + + return compiler.DefaultCompiler.visit_select(self, select) + def limit_clause(self, select): return "" - def __visit_label(self, label): - if len(self.select_stack): - self.typemap.setdefault(label.name.lower(), label.obj.type) - if self.strings[label.obj]: - self.strings[label] = self.strings[label.obj] + " AS " + label.name - else: - self.strings[label] = None - def visit_function( self , func ): if func.name.lower() == 'current_date': return "today" @@ -418,8 +419,8 @@ class InfoCompiler(ansisql.ANSICompiler): elif func.name.lower() in ( 'current_timestamp' , 'now' ): return "CURRENT YEAR TO SECOND" else: - return ansisql.ANSICompiler.visit_function( self , func ) - + return compiler.DefaultCompiler.visit_function( self , func ) + def visit_clauselist(self, list): try: li = [ c for c in list.clauses if c.name != 'oid' ] @@ -427,7 +428,7 @@ class InfoCompiler(ansisql.ANSICompiler): li = [ c for c in list.clauses ] return ', '.join([s for s in [self.process(c) for c in li] if s is not None]) -class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): +class InfoSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, first_pk=False): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ @@ -435,41 +436,41 @@ class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): colspec += " SERIAL" self.has_serial = True else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec 
+= " DEFAULT " + default - + if not column.nullable: colspec += " NOT NULL" - + return colspec - + def post_create_table(self, table): if hasattr( self , 'has_serial' ): del self.has_serial return '' - + def visit_primary_key_constraint(self, constraint): # for informix 7.31 not support constraint name name = constraint.name constraint.name = None super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint) constraint.name = name - + def visit_unique_constraint(self, constraint): # for informix 7.31 not support constraint name name = constraint.name constraint.name = None super(InfoSchemaGenerator, self).visit_unique_constraint(constraint) constraint.name = name - + def visit_foreign_key_constraint( self , constraint ): if constraint.name is not None: constraint.use_alter = True else: super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint ) - + def define_foreign_key(self, constraint): # for informix 7.31 not support constraint name if constraint.use_alter: @@ -488,20 +489,21 @@ class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): return super(InfoSchemaGenerator, self).visit_index(index) -class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class InfoIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") - - def _fold_identifier_case(self, value): - return value.lower() - - def _requires_quotes(self, value, case_sensitive): + + def _requires_quotes(self, value): return False -class InfoSchemaDroper(ansisql.ANSISchemaDropper): +class InfoSchemaDropper(compiler.SchemaDropper): def drop_foreignkey(self, constraint): if constraint.name is not None: - super( InfoSchemaDroper , self ).drop_foreignkey( constraint ) + super( InfoSchemaDropper , self ).drop_foreignkey( constraint ) dialect = InfoDialect poolclass = pool.SingletonThreadPool +dialect.statement_compiler = InfoCompiler +dialect.schemagenerator = InfoSchemaGenerator +dialect.schemadropper = InfoSchemaDropper +dialect.preparer = InfoIdentifierPreparer diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py new file mode 100644 index 0000000000..23ff1f4a00 --- /dev/null +++ b/lib/sqlalchemy/databases/maxdb.py @@ -0,0 +1,1109 @@ +# maxdb.py +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the MaxDB database. + +TODO: More module docs! MaxDB support is currently experimental. + +Overview +-------- + +The ``maxdb`` dialect is **experimental** and has only been tested on 7.6.03.007 +and 7.6.00.037. Of these, **only 7.6.03.007 will work** with SQLAlchemy's ORM. +The earlier version has severe ``LEFT JOIN`` limitations and will return +incorrect results from even very simple ORM queries. + +Only the native Python DB-API is currently supported. ODBC driver support +is a future enhancement. + +Connecting +---------- + +The username is case-sensitive. If you usually connect to the +database with sqlcli and other tools in lower case, you likely need to +use upper case for DB-API. + +Implementation Notes +-------------------- + +Also check the DatabaseNotes page on the wiki for detailed information. + +With the 7.6.00.37 driver and Python 2.5, it seems that all DB-API +generated exceptions are broken and can cause Python to crash. + +For 'somecol.in_([])' to work, the IN operator's generation must be changed +to cast 'NULL' to a numeric, i.e. NUM(NULL). 
The DB-API doesn't accept a +bind parameter there, so that particular generation must inline the NULL value, +which depends on [ticket:807]. + +The DB-API is very picky about where bind params may be used in queries. + +Bind params for some functions (e.g. MOD) need type information supplied. +The dialect does not yet do this automatically. + +Max will occasionally throw up 'bad sql, compile again' exceptions for +perfectly valid SQL. The dialect does not currently handle these, more +research is needed. + +MaxDB 7.5 and Sap DB <= 7.4 reportedly do not support schemas. A very +slightly different version of this dialect would be required to support +those versions, and can easily be added if there is demand. Some other +required components such as an Max-aware 'old oracle style' join compiler +(thetas with (+) outer indicators) are already done and available for +integration- email the devel list if you're interested in working on +this. +""" + +import datetime, itertools, re + +from sqlalchemy import exceptions, schema, sql, util +from sqlalchemy.sql import operators as sql_operators, expression as sql_expr +from sqlalchemy.sql import compiler, visitors +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import types as sqltypes + + +__all__ = [ + 'MaxString', 'MaxUnicode', 'MaxChar', 'MaxText', 'MaxInteger', + 'MaxSmallInteger', 'MaxNumeric', 'MaxFloat', 'MaxTimestamp', + 'MaxDate', 'MaxTime', 'MaxBoolean', 'MaxBlob', + ] + + +class _StringType(sqltypes.String): + _type = None + + def __init__(self, length=None, encoding=None, **kw): + super(_StringType, self).__init__(length=length, **kw) + self.encoding = encoding + + def get_col_spec(self): + if self.length is None: + spec = 'LONG' + else: + spec = '%s(%s)' % (self._type, self.length) + + if self.encoding is not None: + spec = ' '.join([spec, self.encoding.upper()]) + return spec + + def bind_processor(self, dialect): + if self.encoding == 'unicode': + return None + else: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + else: + return value + return process + + def result_processor(self, dialect): + def process(value): + while True: + if value is None: + return None + elif isinstance(value, unicode): + return value + elif isinstance(value, str): + if self.convert_unicode or dialect.convert_unicode: + return value.decode(dialect.encoding) + else: + return value + elif hasattr(value, 'read'): + # some sort of LONG, snarf and retry + value = value.read(value.remainingLength()) + continue + else: + # unexpected type, return as-is + return value + return process + + +class MaxString(_StringType): + _type = 'VARCHAR' + + def __init__(self, *a, **kw): + super(MaxString, self).__init__(*a, **kw) + + +class MaxUnicode(_StringType): + _type = 'VARCHAR' + + def __init__(self, length=None, **kw): + super(MaxUnicode, self).__init__(length=length, encoding='unicode') + + +class MaxChar(_StringType): + _type = 'CHAR' + + +class MaxText(_StringType): + _type = 'LONG' + + def __init__(self, *a, **kw): + super(MaxText, self).__init__(*a, **kw) + + def get_col_spec(self): + spec = 'LONG' + if self.encoding is not None: + spec = ' '.join((spec, self.encoding)) + elif self.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + + +class MaxInteger(sqltypes.Integer): + def get_col_spec(self): + return 'INTEGER' + + +class MaxSmallInteger(MaxInteger): + def get_col_spec(self): + return 'SMALLINT' + + +class MaxNumeric(sqltypes.Numeric): + """The FIXED (also NUMERIC, 
DECIMAL) data type.""" + + def __init__(self, precision=None, length=None, **kw): + kw.setdefault('asdecimal', True) + super(MaxNumeric, self).__init__(length=length, precision=precision, + **kw) + + def bind_processor(self, dialect): + return None + + def get_col_spec(self): + if self.length and self.precision: + return 'FIXED(%s, %s)' % (self.precision, self.length) + elif self.precision: + return 'FIXED(%s)' % self.precision + else: + return 'INTEGER' + + +class MaxFloat(sqltypes.Float): + """The FLOAT data type.""" + + def get_col_spec(self): + if self.precision is None: + return 'FLOAT' + else: + return 'FLOAT(%s)' % (self.precision,) + + +class MaxTimestamp(sqltypes.DateTime): + def get_col_spec(self): + return 'TIMESTAMP' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y%m%d%H%M%S" + ("%06u" % ms)) + elif dialect.datetimeformat == 'iso': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms)) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[4:6], value[6:8], + value[8:10], value[10:12], value[12:14], + value[14:])]) + elif dialect.datetimeformat == 'iso': + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[5:7], value[8:10], + value[11:13], value[14:16], value[17:19], + value[20:])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxDate(sqltypes.Date): + def get_col_spec(self): + return 'DATE' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%Y%m%d") + elif dialect.datetimeformat == 'iso': + return value.strftime("%Y-%m-%d") + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + return datetime.date( + *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) + elif dialect.datetimeformat == 'iso': + return datetime.date( + *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxTime(sqltypes.Time): + def get_col_spec(self): + return 'TIME' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%H%M%S") + elif dialect.datetimeformat == 'iso': + return value.strftime("%H-%M-%S") + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." 
% ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + t = datetime.time( + *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) + return t + elif dialect.datetimeformat == 'iso': + return datetime.time( + *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxBoolean(sqltypes.Boolean): + def get_col_spec(self): + return 'BOOLEAN' + + +class MaxBlob(sqltypes.Binary): + def get_col_spec(self): + return 'LONG BYTE' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return str(value) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return value.read(value.remainingLength()) + return process + + +colspecs = { + sqltypes.Integer: MaxInteger, + sqltypes.Smallinteger: MaxSmallInteger, + sqltypes.Numeric: MaxNumeric, + sqltypes.Float: MaxFloat, + sqltypes.DateTime: MaxTimestamp, + sqltypes.Date: MaxDate, + sqltypes.Time: MaxTime, + sqltypes.String: MaxString, + sqltypes.Binary: MaxBlob, + sqltypes.Boolean: MaxBoolean, + sqltypes.Text: MaxText, + sqltypes.CHAR: MaxChar, + sqltypes.TIMESTAMP: MaxTimestamp, + sqltypes.BLOB: MaxBlob, + sqltypes.Unicode: MaxUnicode, + } + +ischema_names = { + 'boolean': MaxBoolean, + 'char': MaxChar, + 'character': MaxChar, + 'date': MaxDate, + 'fixed': MaxNumeric, + 'float': MaxFloat, + 'int': MaxInteger, + 'integer': MaxInteger, + 'long binary': MaxBlob, + 'long unicode': MaxText, + 'long': MaxText, + 'long': MaxText, + 'smallint': MaxSmallInteger, + 'time': MaxTime, + 'timestamp': MaxTimestamp, + 'varchar': MaxString, + } + + +class MaxDBExecutionContext(default.DefaultExecutionContext): + def post_exec(self): + # DB-API bug: if there were any functions as values, + # then do another select and pull CURRVAL from the + # autoincrement column's implicit sequence... ugh + if self.compiled.isinsert and not self.executemany: + table = self.compiled.statement.table + index, serial_col = _autoserial_column(table) + + if serial_col and (not self.compiled._safeserial or + not(self._last_inserted_ids) or + self._last_inserted_ids[index] in (None, 0)): + if table.schema: + sql = "SELECT %s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + else: + sql = "SELECT CURRENT_SCHEMA.%s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + + if self.connection.engine._should_log_info: + self.connection.engine.logger.info(sql) + rs = self.cursor.execute(sql) + id = rs.fetchone()[0] + + if self.connection.engine._should_log_debug: + self.connection.engine.logger.debug([id]) + if not self._last_inserted_ids: + # This shouldn't ever be > 1? Right? 
+ self._last_inserted_ids = \ + [None] * len(table.primary_key.columns) + self._last_inserted_ids[index] = id + + super(MaxDBExecutionContext, self).post_exec() + + def get_result_proxy(self): + if self.cursor.description is not None: + for column in self.cursor.description: + if column[1] in ('Long Binary', 'Long', 'Long Unicode'): + return MaxDBResultProxy(self) + return engine_base.ResultProxy(self) + + +class MaxDBCachedColumnRow(engine_base.RowProxy): + """A RowProxy that only runs result_processors once per column.""" + + def __init__(self, parent, row): + super(MaxDBCachedColumnRow, self).__init__(parent, row) + self.columns = {} + self._row = row + self._parent = parent + + def _get_col(self, key): + if key not in self.columns: + self.columns[key] = self._parent._get_col(self._row, key) + return self.columns[key] + + def __iter__(self): + for i in xrange(len(self._row)): + yield self._get_col(i) + + def __repr__(self): + return repr(list(self)) + + def __eq__(self, other): + return ((other is self) or + (other == tuple([self._get_col(key) + for key in xrange(len(self._row))]))) + def __getitem__(self, key): + if isinstance(key, slice): + indices = key.indices(len(self._row)) + return tuple([self._get_col(i) for i in xrange(*indices)]) + else: + return self._get_col(key) + + def __getattr__(self, name): + try: + return self._get_col(name) + except KeyError: + raise AttributeError(name) + + +class MaxDBResultProxy(engine_base.ResultProxy): + _process_row = MaxDBCachedColumnRow + + +class MaxDBDialect(default.DefaultDialect): + supports_alter = True + supports_unicode_statements = True + max_identifier_length = 32 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + preexecute_sequences = True + + # MaxDB-specific + datetimeformat = 'internal' + + def __init__(self, _raise_known_sql_errors=False, **kw): + super(MaxDBDialect, self).__init__(**kw) + self._raise_known = _raise_known_sql_errors + + if self.dbapi is None: + self.dbapi_type_map = {} + else: + self.dbapi_type_map = { + 'Long Binary': MaxBlob(), + 'Long byte_t': MaxBlob(), + 'Long Unicode': MaxText(), + 'Timestamp': MaxTimestamp(), + 'Date': MaxDate(), + 'Time': MaxTime(), + datetime.datetime: MaxTimestamp(), + datetime.date: MaxDate(), + datetime.time: MaxTime(), + } + + def dbapi(cls): + from sapdb import dbapi as _dbapi + return _dbapi + dbapi = classmethod(dbapi) + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + return [], opts + + def type_descriptor(self, typeobj): + if isinstance(typeobj, type): + typeobj = typeobj() + if isinstance(typeobj, sqltypes.Unicode): + return typeobj.adapt(MaxUnicode) + else: + return sqltypes.adapt_type(typeobj, colspecs) + + def create_execution_context(self, connection, **kw): + return MaxDBExecutionContext(self, connection, **kw) + + def do_execute(self, cursor, statement, parameters, context=None): + res = cursor.execute(statement, parameters) + if isinstance(res, int) and context is not None: + context._rowcount = res + + def do_release_savepoint(self, connection, name): + # Does MaxDB truly support RELEASE SAVEPOINT ? All my attempts + # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS + # BEGIN SQLSTATE: I7065" + # Note that ROLLBACK TO works fine. In theory, a RELEASE should + # just free up some transactional resources early, before the overall + # COMMIT/ROLLBACK so omitting it should be relatively ok. 
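+        # A rough sketch of the savepoint round trip this no-op still allows
+        # (statement and names below are illustrative only):
+        #
+        #     trans = connection.begin_nested()    # SAVEPOINT issued
+        #     try:
+        #         connection.execute(some_insert)
+        #         trans.commit()                    # RELEASE skipped (here)
+        #     except:
+        #         trans.rollback()                  # ROLLBACK TO SAVEPOINT works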
+ pass + + def get_default_schema_name(self, connection): + try: + return self._default_schema_name + except AttributeError: + name = self.identifier_preparer._normalize_name( + connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) + self._default_schema_name = name + return name + + def has_table(self, connection, table_name, schema=None): + denormalize = self.identifier_preparer._denormalize_name + bind = [denormalize(table_name)] + if schema is None: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME=? AND" + " TABLES.SCHEMANAME=CURRENT_SCHEMA ") + else: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME = ? AND" + " TABLES.SCHEMANAME=? ") + bind.append(denormalize(schema)) + + rp = connection.execute(sql, bind) + found = bool(rp.fetchone()) + rp.close() + return found + + def table_names(self, connection, schema): + if schema is None: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=CURRENT_SCHEMA ") + rs = connection.execute(sql) + else: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=? ") + matchname = self.identifier_preparer._denormalize_name(schema) + rs = connection.execute(sql, matchname) + normalize = self.identifier_preparer._normalize_name + return [normalize(row[0]) for row in rs] + + def reflecttable(self, connection, table, include_columns): + denormalize = self.identifier_preparer._denormalize_name + normalize = self.identifier_preparer._normalize_name + + st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' + ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' + 'FROM COLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY POS') + + fk = ('SELECT COLUMNNAME, FKEYNAME, ' + ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' + ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' + ' THEN 1 ELSE 0 END) AS in_schema ' + 'FROM FOREIGNKEYCOLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY FKEYNAME ') + + params = [denormalize(table.name)] + if not table.schema: + st = st % 'CURRENT_SCHEMA' + fk = fk % 'CURRENT_SCHEMA' + else: + st = st % '?' + fk = fk % '?' + params.append(denormalize(table.schema)) + + rows = connection.execute(st, params).fetchall() + if not rows: + raise exceptions.NoSuchTableError(table.fullname) + + include_columns = util.Set(include_columns or []) + + for row in rows: + (name, mode, col_type, encoding, length, scale, + nullable, constant_def, func_def) = row + + name = normalize(name) + + if include_columns and name not in include_columns: + continue + + type_args, type_kw = [], {} + if col_type == 'FIXED': + type_args = length, scale + # Convert FIXED(10) DEFAULT SERIAL to our Integer + if (scale == 0 and + func_def is not None and func_def.startswith('SERIAL')): + col_type = 'INTEGER' + type_args = length, + elif col_type in 'FLOAT': + type_args = length, + elif col_type in ('CHAR', 'VARCHAR'): + type_args = length, + type_kw['encoding'] = encoding + elif col_type == 'LONG': + type_kw['encoding'] = encoding + + try: + type_cls = ischema_names[col_type.lower()] + type_instance = type_cls(*type_args, **type_kw) + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (col_type, name)) + type_instance = sqltypes.NullType + + col_kw = {'autoincrement': False} + col_kw['nullable'] = (nullable == 'YES') + col_kw['primary_key'] = (mode == 'KEY') + + if func_def is not None: + if func_def.startswith('SERIAL'): + if col_kw['primary_key']: + # No special default- let the standard autoincrement + # support handle SERIAL pk columns. 
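+                        # Hypothetical illustration: a primary key column created
+                        # as "id FIXED(10) DEFAULT SERIAL" arrives here with
+                        # func_def = 'SERIAL' and is reflected as an Integer
+                        # primary key relying on the standard autoincrement
+                        # machinery, with no DDL-level default attached.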
+ col_kw['autoincrement'] = True + else: + # strip current numbering + col_kw['default'] = schema.PassiveDefault( + sql.text('SERIAL')) + col_kw['autoincrement'] = True + else: + col_kw['default'] = schema.PassiveDefault( + sql.text(func_def)) + elif constant_def is not None: + col_kw['default'] = schema.PassiveDefault(sql.text( + "'%s'" % constant_def.replace("'", "''"))) + + table.append_column(schema.Column(name, type_instance, **col_kw)) + + fk_sets = itertools.groupby(connection.execute(fk, params), + lambda row: row.FKEYNAME) + for fkeyname, fkey in fk_sets: + fkey = list(fkey) + if include_columns: + key_cols = util.Set([r.COLUMNNAME for r in fkey]) + if key_cols != include_columns: + continue + + columns, referants = [], [] + quote = self.identifier_preparer._maybe_quote_identifier + + for row in fkey: + columns.append(normalize(row.COLUMNNAME)) + if table.schema or not row.in_schema: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFSCHEMANAME', 'REFTABLENAME', + 'REFCOLUMNNAME')])) + else: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) + + constraint_kw = {'name': fkeyname.lower()} + if fkey[0].RULE is not None: + rule = fkey[0].RULE + if rule.startswith('DELETE '): + rule = rule[7:] + constraint_kw['ondelete'] = rule + + table_kw = {} + if table.schema or not row.in_schema: + table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) + + ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), + table_kw.get('schema')) + if ref_key not in table.metadata.tables: + schema.Table(normalize(fkey[0].REFTABLENAME), + table.metadata, + autoload=True, autoload_with=connection, + **table_kw) + + constraint = schema.ForeignKeyConstraint(columns, referants, + **constraint_kw) + table.append_constraint(constraint) + + def has_sequence(self, connection, name): + # [ticket:726] makes this schema-aware. + denormalize = self.identifier_preparer._denormalize_name + sql = ("SELECT sequence_name FROM SEQUENCES " + "WHERE SEQUENCE_NAME=? ") + + rp = connection.execute(sql, denormalize(name)) + found = bool(rp.fetchone()) + rp.close() + return found + + +class MaxDBCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() + operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) + + function_conversion = { + 'CURRENT_DATE': 'DATE', + 'CURRENT_TIME': 'TIME', + 'CURRENT_TIMESTAMP': 'TIMESTAMP', + } + + # These functions must be written without parens when called with no + # parameters. e.g. 
'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL' + bare_functions = util.Set([ + 'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP', + 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP', + 'UTCDATE', 'UTCDIFF']) + + def default_from(self): + return ' FROM DUAL' + + def for_update_clause(self, select): + clause = select.for_update + if clause is True: + return " WITH LOCK EXCLUSIVE" + elif clause is None: + return "" + elif clause == "read": + return " WITH LOCK" + elif clause == "ignore": + return " WITH LOCK (IGNORE) EXCLUSIVE" + elif clause == "nowait": + return " WITH LOCK (NOWAIT) EXCLUSIVE" + elif isinstance(clause, basestring): + return " WITH LOCK %s" % clause.upper() + elif not clause: + return "" + else: + return " WITH LOCK EXCLUSIVE" + + def apply_function_parens(self, func): + if func.name.upper() in self.bare_functions: + return len(func.clauses) > 0 + else: + return True + + def visit_function(self, fn, **kw): + transform = self.function_conversion.get(fn.name.upper(), None) + if transform: + fn = fn._clone() + fn.name = transform + return super(MaxDBCompiler, self).visit_function(fn, **kw) + + def visit_cast(self, cast, **kwargs): + # MaxDB only supports casts * to NUMERIC, * to VARCHAR or + # date/time to VARCHAR. Casts of LONGs will fail. + if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)): + return "NUM(%s)" % self.process(cast.clause) + elif isinstance(cast.type, sqltypes.String): + return "CHR(%s)" % self.process(cast.clause) + else: + return self.process(cast.clause) + + def visit_sequence(self, sequence): + if sequence.optional: + return None + else: + return (self.dialect.identifier_preparer.format_sequence(sequence) + + ".NEXTVAL") + + class ColumnSnagger(visitors.ClauseVisitor): + def __init__(self): + self.count = 0 + self.column = None + def visit_column(self, column): + self.column = column + self.count += 1 + + def _find_labeled_columns(self, columns, use_labels=False): + labels = {} + for column in columns: + if isinstance(column, basestring): + continue + snagger = self.ColumnSnagger() + snagger.traverse(column) + if snagger.count == 1: + if isinstance(column, sql_expr._Label): + labels[unicode(snagger.column)] = column.name + elif use_labels: + labels[unicode(snagger.column)] = column._label + + return labels + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # ORDER BY clauses in DISTINCT queries must reference aliased + # inner columns by alias name, not true column name. + if order_by and getattr(select, '_distinct', False): + labels = self._find_labeled_columns(select.inner_columns, + select.use_labels) + if labels: + for needs_alias in labels.keys(): + r = re.compile(r'(^| )(%s)(,| |$)' % + re.escape(needs_alias)) + order_by = r.sub((r'\1%s\3' % labels[needs_alias]), + order_by) + + # No ORDER BY in subqueries. + if order_by: + if self.is_subquery(select): + # It's safe to simply drop the ORDER BY if there is no + # LIMIT. Right? Other dialects seem to get away with + # dropping order. 
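+                # Assumed behaviour of the branch below when the statement is
+                # embedded in an enclosing select:
+                #   inner = select([t]).order_by(t.c.x)           -> ORDER BY dropped
+                #   inner = select([t], limit=5).order_by(t.c.x)  -> InvalidRequestError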
+ if select._limit: + raise exceptions.InvalidRequestError( + "MaxDB does not support ORDER BY in subqueries") + else: + return "" + return " ORDER BY " + order_by + else: + return "" + + def get_select_precolumns(self, select): + # Convert a subquery's LIMIT to TOP + sql = select._distinct and 'DISTINCT ' or '' + if self.is_subquery(select) and select._limit: + if select._offset: + raise exceptions.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + sql += 'TOP %s ' % select._limit + return sql + + def limit_clause(self, select): + # The docs say offsets are supported with LIMIT. But they're not. + # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate? + if self.is_subquery(select): + # sub queries need TOP + return '' + elif select._offset: + raise exceptions.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + else: + return ' \n LIMIT %s' % (select._limit,) + + def visit_insert(self, insert): + self.isinsert = True + self._safeserial = True + + colparams = self._get_colparams(insert) + for value in (insert.parameters or {}).itervalues(): + if isinstance(value, sql_expr._Function): + self._safeserial = False + break + + return ''.join(('INSERT INTO ', + self.preparer.format_table(insert.table), + ' (', + ', '.join([self.preparer.format_column(c[0]) + for c in colparams]), + ') VALUES (', + ', '.join([c[1] for c in colparams]), + ')')) + + +class MaxDBDefaultRunner(engine_base.DefaultRunner): + def visit_sequence(self, seq): + if seq.optional: + return None + return self.execute_string("SELECT %s.NEXTVAL FROM DUAL" % ( + self.dialect.identifier_preparer.format_sequence(seq))) + + +class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = util.Set([ + 'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha', + 'alter', 'any', 'ascii', 'asin', 'atan', 'atan2', 'avg', 'binary', + 'bit', 'boolean', 'byte', 'case', 'ceil', 'ceiling', 'char', + 'character', 'check', 'chr', 'column', 'concat', 'constraint', 'cos', + 'cosh', 'cot', 'count', 'cross', 'curdate', 'current', 'curtime', + 'database', 'date', 'datediff', 'day', 'dayname', 'dayofmonth', + 'dayofweek', 'dayofyear', 'dec', 'decimal', 'decode', 'default', + 'degrees', 'delete', 'digits', 'distinct', 'double', 'except', + 'exists', 'exp', 'expand', 'first', 'fixed', 'float', 'floor', 'for', + 'from', 'full', 'get_objectname', 'get_schema', 'graphic', 'greatest', + 'group', 'having', 'hex', 'hextoraw', 'hour', 'ifnull', 'ignore', + 'index', 'initcap', 'inner', 'insert', 'int', 'integer', 'internal', + 'intersect', 'into', 'join', 'key', 'last', 'lcase', 'least', 'left', + 'length', 'lfill', 'list', 'ln', 'locate', 'log', 'log10', 'long', + 'longfile', 'lower', 'lpad', 'ltrim', 'makedate', 'maketime', + 'mapchar', 'max', 'mbcs', 'microsecond', 'min', 'minute', 'mod', + 'month', 'monthname', 'natural', 'nchar', 'next', 'no', 'noround', + 'not', 'now', 'null', 'num', 'numeric', 'object', 'of', 'on', + 'order', 'packed', 'pi', 'power', 'prev', 'primary', 'radians', + 'real', 'reject', 'relative', 'replace', 'rfill', 'right', 'round', + 'rowid', 'rowno', 'rpad', 'rtrim', 'second', 'select', 'selupd', + 'serial', 'set', 'show', 'sign', 'sin', 'sinh', 'smallint', 'some', + 'soundex', 'space', 'sqrt', 'stamp', 'statistics', 'stddev', + 'subdate', 'substr', 'substring', 'subtime', 'sum', 'sysdba', + 'table', 'tan', 'tanh', 'time', 'timediff', 'timestamp', 'timezone', + 'to', 'toidentifier', 'transaction', 'translate', 'trim', 'trunc', + 'truncate', 'ucase', 'uid', 'unicode', 
'union', 'update', 'upper', + 'user', 'usergroup', 'using', 'utcdate', 'utcdiff', 'value', 'values', + 'varchar', 'vargraphic', 'variance', 'week', 'weekofyear', 'when', + 'where', 'with', 'year', 'zoned' ]) + + def _normalize_name(self, name): + if name is None: + return None + if name.isupper(): + lc_name = name.lower() + if not self._requires_quotes(lc_name): + return lc_name + return name + + def _denormalize_name(self, name): + if name is None: + return None + elif (name.islower() and + not self._requires_quotes(name)): + return name.upper() + else: + return name + + def _maybe_quote_identifier(self, name): + if self._requires_quotes(name): + return self.quote_identifier(name) + else: + return name + + +class MaxDBSchemaGenerator(compiler.SchemaGenerator): + def get_column_specification(self, column, **kw): + colspec = [self.preparer.format_column(column), + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()] + + if not column.nullable: + colspec.append('NOT NULL') + + default = column.default + default_str = self.get_column_default_string(column) + + # No DDL default for columns specified with non-optional sequence- + # this defaulting behavior is entirely client-side. (And as a + # consequence, non-reflectable.) + if (default and isinstance(default, schema.Sequence) and + not default.optional): + pass + # Regular default + elif default_str is not None: + colspec.append('DEFAULT %s' % default_str) + # Assign DEFAULT SERIAL heuristically + elif column.primary_key and column.autoincrement: + # For SERIAL on a non-primary key member, use + # PassiveDefault(text('SERIAL')) + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + (isinstance(c.type, sqltypes.Integer) or + (isinstance(c.type, MaxNumeric) and + c.type.precision)) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('DEFAULT SERIAL') + except IndexError: + pass + return ' '.join(colspec) + + def get_column_default_string(self, column): + if isinstance(column.default, schema.PassiveDefault): + if isinstance(column.default.arg, basestring): + if isinstance(column.type, sqltypes.Integer): + return str(column.default.arg) + else: + return "'%s'" % column.default.arg + else: + return unicode(self._compile(column.default.arg, None)) + else: + return None + + def visit_sequence(self, sequence): + """Creates a SEQUENCE. + + TODO: move to module doc? + + start + With an integer value, set the START WITH option. + + increment + An integer value to increment by. Default is the database default. + + maxdb_minvalue + maxdb_maxvalue + With an integer value, sets the corresponding sequence option. + + maxdb_no_minvalue + maxdb_no_maxvalue + Defaults to False. If true, sets the corresponding sequence option. + + maxdb_cycle + Defaults to False. If true, sets the CYCLE option. + + maxdb_cache + With an integer value, sets the CACHE option. + + maxdb_no_cache + Defaults to False. If true, sets NOCACHE. 
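+
+        For example, a hypothetical sequence combining these options:
+
+          Sequence('rev_seq', start=1, increment=1,
+                   maxdb_cache=100, maxdb_no_maxvalue=True)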
+ """ + + if (not sequence.optional and + (not self.checkfirst or + not self.dialect.has_sequence(self.connection, sequence.name))): + + ddl = ['CREATE SEQUENCE', + self.preparer.format_sequence(sequence)] + + sequence.increment = 1 + + if sequence.increment is not None: + ddl.extend(('INCREMENT BY', str(sequence.increment))) + + if sequence.start is not None: + ddl.extend(('START WITH', str(sequence.start))) + + opts = dict([(pair[0][6:].lower(), pair[1]) + for pair in sequence.kwargs.items() + if pair[0].startswith('maxdb_')]) + + if 'maxvalue' in opts: + ddl.extend(('MAXVALUE', str(opts['maxvalue']))) + elif opts.get('no_maxvalue', False): + ddl.append('NOMAXVALUE') + if 'minvalue' in opts: + ddl.extend(('MINVALUE', str(opts['minvalue']))) + elif opts.get('no_minvalue', False): + ddl.append('NOMINVALUE') + + if opts.get('cycle', False): + ddl.append('CYCLE') + + if 'cache' in opts: + ddl.extend(('CACHE', str(opts['cache']))) + elif opts.get('no_cache', False): + ddl.append('NOCACHE') + + self.append(' '.join(ddl)) + self.execute() + + +class MaxDBSchemaDropper(compiler.SchemaDropper): + def visit_sequence(self, sequence): + if (not sequence.optional and + (not self.checkfirst or + self.dialect.has_sequence(self.connection, sequence.name))): + self.append("DROP SEQUENCE %s" % + self.preparer.format_sequence(sequence)) + self.execute() + + +def _autoserial_column(table): + """Finds the effective DEFAULT SERIAL column of a Table, if any.""" + + for index, col in enumerate(table.primary_key.columns): + if (isinstance(col.type, (sqltypes.Integer, sqltypes.Numeric)) and + col.autoincrement): + if isinstance(col.default, schema.Sequence): + if col.default.optional: + return index, col + elif (col.default is None or + (not isinstance(col.default, schema.PassiveDefault))): + return index, col + + return None, None + +def descriptor(): + return {'name': 'maxdb', + 'description': 'MaxDB', + 'arguments': [ + ('user', "Database Username", None), + ('password', "Database Password", None), + ('database', "Database Name", None), + ('host', "Hostname", None)]} + +dialect = MaxDBDialect +dialect.preparer = MaxDBIdentifierPreparer +dialect.statement_compiler = MaxDBCompiler +dialect.schemagenerator = MaxDBSchemaGenerator +dialect.schemadropper = MaxDBSchemaDropper +dialect.defaultrunner = MaxDBDefaultRunner diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 2062914041..ab5a968716 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -20,41 +20,57 @@ Note that the start & increment values for sequences are optional and will default to 1,1. 
-* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for +* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for ``INSERT`` s) * Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT`` * ``select._limit`` implemented as ``SELECT TOP n`` +* Experimental implemention of LIMIT / OFFSET with row_number() Known issues / TODO: * No support for more than one ``IDENTITY`` column per table -* No support for ``GUID`` type columns (yet) - * pymssql has problems with binary and unicode data that this module does **not** work around - + """ -import datetime, random, warnings +import datetime, operator, re, sys -from sqlalchemy import sql, schema, ansisql, exceptions -import sqlalchemy.types as sqltypes -from sqlalchemy.engine import default - -class MSNumeric(sqltypes.Numeric): - def convert_result_value(self, value, dialect): - return value +from sqlalchemy import sql, schema, exceptions, util +from sqlalchemy.sql import compiler, expression, operators as sqlops, functions as sql_functions +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes +from sqlalchemy.util import Decimal as _python_Decimal - def convert_bind_param(self, value, dialect): - if value is None: - # Not sure that this exception is needed - return value + +MSSQL_RESERVED_WORDS = util.Set(['function']) + +class MSNumeric(sqltypes.Numeric): + def result_processor(self, dialect): + if self.asdecimal: + def process(value): + if value is not None: + return _python_Decimal(str(value)) + else: + return value + return process else: - return str(value) + def process(value): + return float(value) + return process + + def bind_processor(self, dialect): + def process(value): + if value is None: + # Not sure that this exception is needed + return value + else: + return str(value) + return process def get_col_spec(self): if self.precision is None: @@ -66,21 +82,27 @@ class MSFloat(sqltypes.Float): def get_col_spec(self): return "FLOAT(%(precision)s)" % {'precision': self.precision} - def convert_bind_param(self, value, dialect): - """By converting to string, we can use Decimal types round-trip.""" - if not value is None: - return str(value) - return None + def bind_processor(self, dialect): + def process(value): + """By converting to string, we can use Decimal types round-trip.""" + if not value is None: + return str(value) + return None + return process class MSInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" -class MSTinyInteger(sqltypes.Integer): +class MSBigInteger(MSInteger): + def get_col_spec(self): + return "BIGINT" + +class MSTinyInteger(MSInteger): def get_col_spec(self): return "TINYINT" -class MSSmallInteger(sqltypes.Smallinteger): +class MSSmallInteger(MSInteger): def get_col_spec(self): return "SMALLINT" @@ -91,77 +113,93 @@ class MSDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATETIME" -class MSDate(sqltypes.Date): +class MSSmallDate(sqltypes.Date): def __init__(self, *a, **kw): super(MSDate, self).__init__(False) def get_col_spec(self): return "SMALLDATETIME" + def result_processor(self, dialect): + def process(value): + # If the DBAPI returns the value as datetime.datetime(), truncate it back to datetime.date() + if type(value) is datetime.datetime: + return value.date() + return value + return process + +class MSDate(sqltypes.Date): + def __init__(self, *a, **kw): + super(MSDate, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + + def result_processor(self, dialect): + def process(value): + # 
If the DBAPI returns the value as datetime.datetime(), truncate it back to datetime.date() + if type(value) is datetime.datetime: + return value.date() + return value + return process + class MSTime(sqltypes.Time): __zero_date = datetime.date(1900, 1, 1) def __init__(self, *a, **kw): super(MSTime, self).__init__(False) - + def get_col_spec(self): return "DATETIME" - def convert_bind_param(self, value, dialect): - if isinstance(value, datetime.datetime): - value = datetime.datetime.combine(self.__zero_date, value.time()) - elif isinstance(value, datetime.time): - value = datetime.datetime.combine(self.__zero_date, value) - return value - - def convert_result_value(self, value, dialect): - if isinstance(value, datetime.datetime): - return value.time() - elif isinstance(value, datetime.date): - return datetime.time(0, 0, 0) - return value + def bind_processor(self, dialect): + def process(value): + if type(value) is datetime.datetime: + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif type(value) is datetime.time: + value = datetime.datetime.combine(self.__zero_date, value) + return value + return process + + def result_processor(self, dialect): + def process(value): + if type(value) is datetime.datetime: + return value.time() + elif type(value) is datetime.date: + return datetime.time(0, 0, 0) + return value + return process class MSDateTime_adodbapi(MSDateTime): - def convert_result_value(self, value, dialect): - # adodbapi will return datetimes with empty time values as datetime.date() objects. - # Promote them back to full datetime.datetime() - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - return value + def result_processor(self, dialect): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) + return value + return process class MSDateTime_pyodbc(MSDateTime): - def convert_bind_param(self, value, dialect): - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - else: + def bind_processor(self, dialect): + def process(value): + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) return value + return process class MSDate_pyodbc(MSDate): - def convert_bind_param(self, value, dialect): - if value and not hasattr(value, 'second'): - return datetime.datetime(value.year, value.month, value.day) - else: - return value - - def convert_result_value(self, value, dialect): - # pyodbc returns SMALLDATETIME values as datetime.datetime(). 
truncate it back to datetime.date() - if value and hasattr(value, 'second'): - return value.date() - else: - return value - -class MSDate_pymssql(MSDate): - def convert_result_value(self, value, dialect): - # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() - if value and hasattr(value, 'second'): - return value.date() - else: + def bind_processor(self, dialect): + def process(value): + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) return value + return process -class MSText(sqltypes.TEXT): +class MSText(sqltypes.Text): def get_col_spec(self): if self.dialect.text_as_varchar: - return "VARCHAR(max)" + return "VARCHAR(max)" else: return "TEXT" @@ -180,11 +218,11 @@ class MSNVarchar(sqltypes.Unicode): class AdoMSNVarchar(MSNVarchar): """overrides bindparam/result processing to not convert any unicode strings""" - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - return value + def result_processor(self, dialect): + return None class MSChar(sqltypes.CHAR): def get_col_spec(self): @@ -202,25 +240,45 @@ class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BIT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + class MSTimeStamp(sqltypes.TIMESTAMP): def get_col_spec(self): return "TIMESTAMP" - + +class MSMoney(sqltypes.TypeEngine): + def get_col_spec(self): + return "MONEY" + +class MSSmallMoney(MSMoney): + def get_col_spec(self): + return "SMALLMONEY" + +class MSUniqueIdentifier(sqltypes.TypeEngine): + def get_col_spec(self): + return "UNIQUEIDENTIFIER" + +class MSVariant(sqltypes.TypeEngine): + def get_col_spec(self): + return "SQL_VARIANT" + def descriptor(): return {'name':'mssql', 'description':'MSSQL', @@ -238,7 +296,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): def _has_implicit_sequence(self, column): if column.primary_key and column.autoincrement: - if isinstance(column.type, sqltypes.Integer) and not column.foreign_key: + if isinstance(column.type, sqltypes.Integer) and not column.foreign_keys: if column.default is None or (isinstance(column.default, schema.Sequence) and \ column.default.optional): return True @@ -247,7 +305,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): def pre_exec(self): """MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. - + Activate it if the feature is turned on and needed. 
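+
+        A rough sketch of the case this targets (table and values are made
+        up): supplying an explicit value for the IDENTITY column causes
+        "SET IDENTITY_INSERT <table> ON" to be emitted before the INSERT.
+
+          t = Table('t', meta, Column('id', Integer, primary_key=True),
+                    Column('x', String(10)))
+          t.insert().execute(id=7, x='y')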
""" if self.compiled.isinsert: @@ -269,7 +327,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): self.IINSERT = False if self.IINSERT: - self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.dialect.preparer().format_table(self.compiled.statement.table)) + self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) super(MSSQLExecutionContext, self).pre_exec() @@ -278,48 +336,40 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): and fetch recently inserted IDENTIFY values (works only for one column). """ - - if self.compiled.isinsert: - if self.IINSERT: - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.preparer().format_table(self.compiled.statement.table)) - self.IINSERT = False - elif self.HASIDENT: - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - if self.dialect.use_scope_identity: - self.cursor.execute("SELECT scope_identity() AS lastrowid") - else: - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] - # print "LAST ROW ID", self._last_inserted_ids - self.HASIDENT = False + + if self.compiled.isinsert and (not self.executemany) and self.HASIDENT and not self.IINSERT: + if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + if self.dialect.use_scope_identity: + self.cursor.execute("SELECT scope_identity() AS lastrowid") + else: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] super(MSSQLExecutionContext, self).post_exec() + _ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns|EXEC)', + re.I | re.UNICODE) + + def returns_rows_text(self, statement): + return self._ms_is_select.match(statement) is not None -class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): + +class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): def pre_exec(self): - """execute "set nocount on" on all connections, as a partial - workaround for multiple result set issues.""" - - if not getattr(self.connection, 'pyodbc_done_nocount', False): - self.connection.execute('SET nocount ON') - self.connection.pyodbc_done_nocount = True - + """where appropriate, issue "select scope_identity()" in the same statement""" super(MSSQLExecutionContext_pyodbc, self).pre_exec() - - # where appropriate, issue "select scope_identity()" in the same statement - if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity: + if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) \ + and len(self.parameters) == 1 and self.dialect.use_scope_identity: self.statement += "; select scope_identity()" def post_exec(self): if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity: # do nothing - id was fetched in dialect.do_execute() - self.HASIDENT = False + pass else: super(MSSQLExecutionContext_pyodbc, self).post_exec() - -class MSSQLDialect(ansisql.ANSIDialect): +class MSSQLDialect(default.DefaultDialect): colspecs = { sqltypes.Unicode : MSNVarchar, sqltypes.Integer : MSInteger, @@ -332,7 +382,7 @@ class MSSQLDialect(ansisql.ANSIDialect): sqltypes.String : MSString, sqltypes.Binary : MSBinary, sqltypes.Boolean : MSBoolean, - sqltypes.TEXT : MSText, + sqltypes.Text : MSText, sqltypes.CHAR: MSChar, sqltypes.NCHAR: MSNChar, sqltypes.TIMESTAMP: MSTimeStamp, @@ 
-340,6 +390,7 @@ class MSSQLDialect(ansisql.ANSIDialect): ischema_names = { 'int' : MSInteger, + 'bigint': MSBigInteger, 'smallint' : MSSmallInteger, 'tinyint' : MSTinyInteger, 'varchar' : MSString, @@ -352,12 +403,18 @@ class MSSQLDialect(ansisql.ANSIDialect): 'numeric' : MSNumeric, 'float' : MSFloat, 'datetime' : MSDateTime, - 'smalldatetime' : MSDate, + 'date': MSDate, + 'smalldatetime' : MSSmallDate, 'binary' : MSBinary, + 'varbinary' : MSBinary, 'bit': MSBoolean, 'real' : MSFloat, 'image' : MSBinary, 'timestamp': MSTimeStamp, + 'money': MSMoney, + 'smallmoney': MSSmallMoney, + 'uniqueidentifier': MSUniqueIdentifier, + 'sql_variant': MSVariant, } def __new__(cls, dbapi=None, *args, **kwargs): @@ -368,12 +425,13 @@ class MSSQLDialect(ansisql.ANSIDialect): return dialect(*args, **kwargs) else: return object.__new__(cls, *args, **kwargs) - + def __init__(self, auto_identity_insert=True, **params): super(MSSQLDialect, self).__init__(**params) self.auto_identity_insert = auto_identity_insert self.text_as_varchar = False self.use_scope_identity = False + self.has_window_funcs = False self.set_default_schema_name("dbo") def dbapi(cls, module_name=None): @@ -392,18 +450,20 @@ class MSSQLDialect(ansisql.ANSIDialect): else: raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi') dbapi = classmethod(dbapi) - + def create_connect_args(self, url): - opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) + opts = url.translate_connect_args(username='user') opts.update(url.query) - if opts.has_key('auto_identity_insert'): + if 'auto_identity_insert' in opts: self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert'))) - if opts.has_key('query_timeout'): + if 'query_timeout' in opts: self.query_timeout = int(opts.pop('query_timeout')) - if opts.has_key('text_as_varchar'): + if 'text_as_varchar' in opts: self.text_as_varchar = bool(int(opts.pop('text_as_varchar'))) - if opts.has_key('use_scope_identity'): + if 'use_scope_identity' in opts: self.use_scope_identity = bool(int(opts.pop('use_scope_identity'))) + if 'has_window_funcs' in opts: + self.has_window_funcs = bool(int(opts.pop('has_window_funcs'))) return self.make_connect_string(opts) def create_execution_context(self, *args, **kwargs): @@ -419,26 +479,7 @@ class MSSQLDialect(ansisql.ANSIDialect): def last_inserted_ids(self): return self.context.last_inserted_ids - # this is only implemented in the dbapi-specific subclasses - def supports_sane_rowcount(self): - raise NotImplementedError() - - def compiler(self, statement, bindparams, **kwargs): - return MSSQLCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return MSSQLSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return MSSQLSchemaDropper(self, *args, **kwargs) - - def defaultrunner(self, connection, **kwargs): - return MSSQLDefaultRunner(connection, **kwargs) - - def preparer(self): - return MSSQLIdentifierPreparer(self) - - def get_default_schema_name(self): + def get_default_schema_name(self, connection): return self.schema_name def set_default_schema_name(self, schema_name): @@ -446,11 +487,22 @@ class MSSQLDialect(ansisql.ANSIDialect): def last_inserted_ids(self): return self.context.last_inserted_ids - - def do_execute(self, cursor, statement, params, **kwargs): + + def do_execute(self, cursor, statement, params, context=None, **kwargs): if params == {}: params = () - super(MSSQLDialect, self).do_execute(cursor, 
statement, params, **kwargs) + try: + super(MSSQLDialect, self).do_execute(cursor, statement, params, context=context, **kwargs) + finally: + if context.IINSERT: + cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table)) + + def do_executemany(self, cursor, statement, params, context=None, **kwargs): + try: + super(MSSQLDialect, self).do_executemany(cursor, statement, params, context=context, **kwargs) + finally: + if context.IINSERT: + cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table)) def _execute(self, c, statement, parameters): try: @@ -460,12 +512,16 @@ class MSSQLDialect(ansisql.ANSIDialect): self.context.rowcount = c.rowcount c.DBPROP_COMMITPRESERVE = "Y" except Exception, e: - raise exceptions.SQLError(statement, parameters, e) + raise exceptions.DBAPIError.instance(statement, parameters, e) + + def table_names(self, connection, schema): + from sqlalchemy.databases import information_schema as ischema + return ischema.table_names(connection, schema) def raw_connection(self, connection): """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes""" try: - # TODO: probably want to move this to individual dialect subclasses to + # TODO: probably want to move this to individual dialect subclasses to # save on the exception throw + simplify return connection.connection.__dict__['_pymssqlCnx__cnx'] except: @@ -483,26 +539,26 @@ class MSSQLDialect(ansisql.ANSIDialect): def has_table(self, connection, tablename, schema=None): import sqlalchemy.databases.information_schema as ischema - current_schema = schema or self.get_default_schema_name() + current_schema = schema or self.get_default_schema_name(connection) columns = self.uppercase_table(ischema.columns) s = sql.select([columns], current_schema and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) or columns.c.table_name==tablename, ) - + c = connection.execute(s) row = c.fetchone() return row is not None - + def reflecttable(self, connection, table, include_columns): import sqlalchemy.databases.information_schema as ischema - + # Get base columns if table.schema is not None: current_schema = table.schema else: - current_schema = self.get_default_schema_name() + current_schema = self.get_default_schema_name(connection) columns = self.uppercase_table(ischema.columns) s = sql.select([columns], @@ -510,7 +566,7 @@ class MSSQLDialect(ansisql.ANSIDialect): and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) or columns.c.table_name==table.name, order_by=[columns.c.ordinal_position]) - + c = connection.execute(s) found_table = False while True: @@ -519,9 +575,9 @@ class MSSQLDialect(ansisql.ANSIDialect): break found_table = True (name, type, nullable, charlen, numericprec, numericscale, default) = ( - row[columns.c.column_name], - row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', + row[columns.c.column_name], + row[columns.c.data_type], + row[columns.c.is_nullable] == 'YES', row[columns.c.character_maximum_length], row[columns.c.numeric_precision], row[columns.c.numeric_scale], @@ -536,26 +592,27 @@ class MSSQLDialect(ansisql.ANSIDialect): args.append(a) coltype = self.ischema_names.get(type, None) if coltype == MSString and charlen == -1: - coltype = MSText() + coltype = MSText() else: if coltype is None: - warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % 
(type, name))) + util.warn("Did not recognize type '%s' of column '%s'" % + (type, name)) coltype = sqltypes.NULLTYPE - - elif coltype == MSNVarchar and charlen == -1: - charlen = None + + elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1: + args[0] = None coltype = coltype(*args) colargs= [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) - + + table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs)) + if not found_table: raise exceptions.NoSuchTableError(table.name) # We also run an sp_columns to check for identity columns: - cursor = connection.execute("sp_columns " + self.preparer().format_table(table)) + cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema)) ic = None while True: row = cursor.fetchone() @@ -564,6 +621,7 @@ class MSSQLDialect(ansisql.ANSIDialect): col_name, type_name = row[3], row[5] if type_name.endswith("identity"): ic = table.c[col_name] + ic.autoincrement = True # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute ic.sequence = schema.Sequence(ic.name + '_identity') # MSSQL: only one identity per table allowed @@ -584,7 +642,7 @@ class MSSQLDialect(ansisql.ANSIDialect): # Add constraints RR = self.uppercase_table(ischema.ref_constraints) #information_schema.referential_constraints TC = self.uppercase_table(ischema.constraints) #information_schema.table_constraints - C = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column + C = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column R = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column # Primary key constraints @@ -600,7 +658,7 @@ class MSSQLDialect(ansisql.ANSIDialect): R.c.table_schema, R.c.table_name, R.c.column_name, RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule], sql.and_(C.c.table_name == table.name, - C.c.table_schema == current_schema, + C.c.table_schema == (table.schema or current_schema), C.c.constraint_name == RR.c.constraint_name, R.c.constraint_name == RR.c.unique_constraint_name, C.c.ordinal_position == R.c.ordinal_position @@ -608,21 +666,37 @@ class MSSQLDialect(ansisql.ANSIDialect): order_by = [RR.c.constraint_name, R.c.ordinal_position]) rows = connection.execute(s).fetchall() + def _gen_fkref(table, rschema, rtbl, rcol): + if table.schema and rschema != table.schema or rschema != current_schema: + return '.'.join([rschema, rtbl, rcol]) + else: + return '.'.join([rtbl, rcol]) + # group rows by constraint ID, to handle multi-column FKs fknm, scols, rcols = (None, [], []) for r in rows: scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r + + if table.schema and rschema != table.schema or rschema != current_schema: + schema.Table(rtbl, table.metadata, schema=rschema, autoload=True, autoload_with=connection) + else: + schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection) + if rfknm != fknm: if fknm: - table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) + table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table,s,t,c) for s,t,c in rcols], 
fknm)) fknm, scols, rcols = (rfknm, [], []) if (not scol in scols): scols.append(scol) if (not (rschema, rtbl, rcol) in rcols): rcols.append((rschema, rtbl, rcol)) if fknm and scols: - table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) + table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table,s,t,c) for s,t,c in rcols], fknm)) + class MSSQLDialect_pymssql(MSSQLDialect): + supports_sane_rowcount = False + max_identifier_length = 30 + def import_dbapi(cls): import pymssql as module # pymmsql doesn't have a Binary method. we use string @@ -630,22 +704,17 @@ class MSSQLDialect_pymssql(MSSQLDialect): module.Binary = lambda st: str(st) return module import_dbapi = classmethod(import_dbapi) - - colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Date] = MSDate_pymssql ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['smalldatetime'] = MSDate_pymssql + def __init__(self, **params): super(MSSQLDialect_pymssql, self).__init__(**params) self.use_scope_identity = True - def supports_sane_rowcount(self): - return False - - def max_identifier_length(self): - return 30 + # pymssql understands only ascii + if self.convert_unicode: + self.encoding = params.get('encoding', 'ascii') def do_rollback(self, connection): # pymssql throws an error on repeated rollbacks. Ignore it. @@ -672,74 +741,72 @@ class MSSQLDialect_pymssql(MSSQLDialect): return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) -## This code is leftover from the initial implementation, for reference -## def do_begin(self, connection): -## """implementations might want to put logic here for turning autocommit on/off, etc.""" -## pass - -## def do_rollback(self, connection): -## """implementations might want to put logic here for turning autocommit on/off, etc.""" -## try: -## # connection.rollback() for pymmsql failed sometimes--the begin tran doesn't show up -## # this is a workaround that seems to be handle it. -## r = self.raw_connection(connection) -## r.query("if @@trancount > 0 rollback tran") -## r.fetch_array() -## r.query("begin tran") -## r.fetch_array() -## except: -## pass - -## def do_commit(self, connection): -## """implementations might want to put logic here for turning autocommit on/off, etc. -## do_commit is set for pymmsql connections--ADO seems to handle transactions without any issue -## """ -## # ADO Uses Implicit Transactions. -## # This is very pymssql specific. We use this instead of its commit, because it hangs on failed rollbacks. -## # By using the "if" we don't assume an open transaction--much better. 
-## r = self.raw_connection(connection) -## r.query("if @@trancount > 0 commit tran") -## r.fetch_array() -## r.query("begin tran") -## r.fetch_array() - class MSSQLDialect_pyodbc(MSSQLDialect): - + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + # PyODBC unicode is broken on UCS-4 builds + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = supports_unicode + + def __init__(self, **params): + super(MSSQLDialect_pyodbc, self).__init__(**params) + # whether use_scope_identity will work depends on the version of pyodbc + try: + import pyodbc + self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset') + except: + pass + def import_dbapi(cls): import pyodbc as module return module import_dbapi = classmethod(import_dbapi) - + colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Unicode] = AdoMSNVarchar + if supports_unicode: + colspecs[sqltypes.Unicode] = AdoMSNVarchar colspecs[sqltypes.Date] = MSDate_pyodbc colspecs[sqltypes.DateTime] = MSDateTime_pyodbc ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['nvarchar'] = AdoMSNVarchar + if supports_unicode: + ischema_names['nvarchar'] = AdoMSNVarchar ischema_names['smalldatetime'] = MSDate_pyodbc ischema_names['datetime'] = MSDateTime_pyodbc - def supports_sane_rowcount(self): - return False - - def supports_unicode_statements(self): - """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" - return True - def make_connect_string(self, keys): - connectors = ["Driver={SQL Server}"] - if 'port' in keys: - connectors.append('Server=%s,%d' % (keys.get('host'), keys.get('port'))) + if 'max_identifier_length' in keys: + self.max_identifier_length = int(keys.pop('max_identifier_length')) + if 'dsn' in keys: + connectors = ['dsn=%s' % keys['dsn']] else: - connectors.append('Server=%s' % keys.get('host')) - connectors.append("Database=%s" % keys.get("database")) + connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'), + 'Server=%s' % keys['host'], + 'Database=%s' % keys['database'] ] + if 'port' in keys: + connectors.append('Port=%d' % int(keys['port'])) + user = keys.get("user") if user: connectors.append("UID=%s" % user) connectors.append("PWD=%s" % keys.get("password", "")) else: - connectors.append ("TrustedConnection=Yes") + connectors.append("TrustedConnection=Yes") + + # if set to 'Yes', the ODBC layer will try to automagically convert + # textual data from your database encoding to your client encoding + # This should obviously be set to 'No' if you query a cp1253 encoded + # database from a latin1 client... + if 'odbc_autotranslate' in keys: + connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) + + # Allow specification of partial ODBC connect string + if 'odbc_options' in keys: + odbc_options=keys.pop('odbc_options') + if odbc_options[0]=="'" and odbc_options[-1]=="'": + odbc_options=odbc_options[1:-1] + connectors.append(odbc_options) + return [[";".join (connectors)], {}] def is_disconnect(self, e): @@ -752,17 +819,22 @@ class MSSQLDialect_pyodbc(MSSQLDialect): super(MSSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) if context and context.HASIDENT and (not context.IINSERT) and context.dialect.use_scope_identity: import pyodbc - # fetch the last inserted id from the manipulated statement (pre_exec). 
- try: - row = cursor.fetchone() - except pyodbc.Error, e: - # if nocount OFF fetchone throws an exception and we have to jump over - # the rowcount to the resultset - cursor.nextset() - row = cursor.fetchone() + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with no data (due to triggers, etc.) + while True: + try: + row = cursor.fetchone() + break + except pyodbc.Error, e: + cursor.nextset() context._last_inserted_ids = [int(row[0])] class MSSQLDialect_adodbapi(MSSQLDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = True + def import_dbapi(cls): import adodbapi as module return module @@ -776,13 +848,6 @@ class MSSQLDialect_adodbapi(MSSQLDialect): ischema_names['nvarchar'] = AdoMSNVarchar ischema_names['datetime'] = MSDateTime_adodbapi - def supports_sane_rowcount(self): - return True - - def supports_unicode_statements(self): - """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" - return True - def make_connect_string(self, keys): connectors = ["Provider=SQLOLEDB"] if 'port' in keys: @@ -808,32 +873,71 @@ dialect_mapping = { } -class MSSQLCompiler(ansisql.ANSICompiler): - def __init__(self, dialect, statement, parameters, **kwargs): - super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) +class MSSQLCompiler(compiler.DefaultCompiler): + operators = compiler.OPERATORS.copy() + operators[sqlops.concat_op] = '+' + + functions = compiler.DefaultCompiler.functions.copy() + functions.update ( + { + sql_functions.now: 'CURRENT_TIMESTAMP' + } + ) + + def __init__(self, *args, **kwargs): + super(MSSQLCompiler, self).__init__(*args, **kwargs) self.tablealiases = {} def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ - s = select._distinct and "DISTINCT " or "" - if select._limit: - s += "TOP %s " % (select._limit,) - if select._offset: - raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset') - return s - - def limit_clause(self, select): + if not self.dialect.has_window_funcs: + s = select._distinct and "DISTINCT " or "" + if select._limit: + s += "TOP %s " % (select._limit,) + if select._offset: + raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset') + return s + return compiler.DefaultCompiler.get_select_precolumns(self, select) + + def limit_clause(self, select): # Limit in mssql is after the select keyword return "" - + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + """ + if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._limit is not None or select._offset is not None): + # to use ROW_NUMBER(), an ORDER BY is required. 
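+            # Roughly, the rewrite below wraps the original statement in a
+            # subquery of this shape (illustrative, not literal output):
+            #
+            #   SELECT ... FROM (
+            #       SELECT ..., ROW_NUMBER() OVER (ORDER BY <order>) AS mssql_rn
+            #       FROM <original froms>
+            #   ) AS anon
+            #   WHERE mssql_rn >= <offset> AND mssql_rn <= <offset> + <limit>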
+ orderby = self.process(select._order_by_clause) + if not orderby: + orderby = list(select.oid_column.proxies)[0] + orderby = self.process(orderby) + + _offset = select._offset + _limit = select._limit + select._mssql_visit = True + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() + + limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) + if _offset is not None: + limitselect.append_whereclause("mssql_rn>=%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) + else: + limitselect.append_whereclause("mssql_rn<=%d" % _limit) + return self.process(limitselect, iswrapper=True, **kwargs) + else: + return compiler.DefaultCompiler.visit_select(self, select, **kwargs) + def _schema_aliased_table(self, table): if getattr(table, 'schema', None) is not None: - if not self.tablealiases.has_key(table): + if table not in self.tablealiases: self.tablealiases[table] = table.alias() return self.tablealiases[table] else: return None - + def visit_table(self, table, mssql_aliased=False, **kwargs): if mssql_aliased: return super(MSSQLCompiler, self).visit_table(table, **kwargs) @@ -844,39 +948,46 @@ class MSSQLCompiler(ansisql.ANSICompiler): return self.process(alias, mssql_aliased=True, **kwargs) else: return super(MSSQLCompiler, self).visit_table(table, **kwargs) - + def visit_alias(self, alias, **kwargs): # translate for schema-qualified table aliases self.tablealiases[alias.original] = alias + kwargs['mssql_aliased'] = True return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) - def visit_column(self, column): - if column.table is not None: + def visit_column(self, column, result_map=None, **kwargs): + if column.table is not None and not self.isupdate and not self.isdelete: # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: - return self.process(t.corresponding_column(column)) - return super(MSSQLCompiler, self).visit_column(column) + converted = expression._corresponding_column_or_error(t, column) + + if result_map is not None: + result_map[column.name.lower()] = (column.name, (column, ), column.type) + + return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) + + return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) - def visit_binary(self, binary): + def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq: - return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator)) + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) else: - return super(MSSQLCompiler, self).visit_binary(binary) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) - def label_select_column(self, select, column): - if isinstance(column, sql._Function): - return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression._Function): + return column.label(None) else: - return super(MSSQLCompiler, self).label_select_column(select, column) + return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) 
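+    # Generic SQL function names that need to be rewritten to their MSSQL
+    # equivalents; visit_function() below swaps the name before rendering.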
function_rewrites = {'current_date': 'getdate', 'length': 'len', } - def visit_function(self, func): + def visit_function(self, func, **kwargs): func.name = self.function_rewrites.get(func.name, func.name) - super(MSSQLCompiler, self).visit_function(func) + return super(MSSQLCompiler, self).visit_function(func, **kwargs) def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use @@ -892,13 +1003,13 @@ class MSSQLCompiler(ansisql.ANSICompiler): return "" -class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): +class MSSQLSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() - + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + # install a IDENTITY Sequence if we have an implicit IDENTITY column if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ - column.autoincrement and isinstance(column.type, sqltypes.Integer) and not column.foreign_key: + column.autoincrement and isinstance(column.type, sqltypes.Integer) and not column.foreign_keys: if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): column.sequence = schema.Sequence(column.name + '_seq') @@ -912,10 +1023,10 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default - + return colspec -class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): +class MSSQLSchemaDropper(compiler.SchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s.%s" % ( self.preparer.quote_identifier(index.table.name), @@ -924,11 +1035,14 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): self.execute() -class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): +class MSSQLDefaultRunner(base.DefaultRunner): # TODO: does ms-sql have standalone sequences ? + # A: No, only auto-incrementing IDENTITY property of a column pass -class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class MSSQLIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS) + def __init__(self, dialect): super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') @@ -936,11 +1050,9 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): #TODO: determin MSSQL's escapeing rules return value - def _fold_identifier_case(self, value): - #TODO: determin MSSQL's case folding rules - return value - dialect = MSSQLDialect - - - +dialect.statement_compiler = MSSQLCompiler +dialect.schemagenerator = MSSQLSchemaGenerator +dialect.schemadropper = MSSQLSchemaDropper +dialect.preparer = MSSQLIdentifierPreparer +dialect.defaultrunner = MSSQLDefaultRunner diff --git a/lib/sqlalchemy/databases/mxODBC.py b/lib/sqlalchemy/databases/mxODBC.py new file mode 100644 index 0000000000..a3acac5870 --- /dev/null +++ b/lib/sqlalchemy/databases/mxODBC.py @@ -0,0 +1,60 @@ +# mxODBC.py +# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch +# Coding: Alexander Houben alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +A wrapper for a mx.ODBC.Windows DB-API connection. 
+ +Makes sure the mx module is configured to return datetime objects instead +of mx.DateTime.DateTime objects. +""" + +from mx.ODBC.Windows import * + + +class Cursor: + def __init__(self, cursor): + self.cursor = cursor + + def __getattr__(self, attr): + res = getattr(self.cursor, attr) + return res + + def execute(self, *args, **kwargs): + res = self.cursor.execute(*args, **kwargs) + return res + + +class Connection: + def myErrorHandler(self, connection, cursor, errorclass, errorvalue): + err0, err1, err2, err3 = errorvalue + #print ", ".join(["Err%d: %s"%(x, errorvalue[x]) for x in range(4)]) + if int(err1) == 109: + # Ignore "Null value eliminated in aggregate function", this is not an error + return + raise errorclass, errorvalue + + def __init__(self, conn): + self.conn = conn + # install a mx ODBC error handler + self.conn.errorhandler = self.myErrorHandler + + def __getattr__(self, attr): + res = getattr(self.conn, attr) + return res + + def cursor(self, *args, **kwargs): + res = Cursor(self.conn.cursor(*args, **kwargs)) + return res + + +# override 'connect' call +def connect(*args, **kwargs): + import mx.ODBC.Windows + conn = mx.ODBC.Windows.Connect(*args, **kwargs) + conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT + return Connection(conn) +Connect = connect diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 26800e32b3..a86035be50 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1,23 +1,179 @@ # mysql.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, datetime, inspect, warnings, weakref, operator +"""Support for the MySQL database. + +SQLAlchemy supports 6 major MySQL versions: 3.23, 4.0, 4.1, 5.0, 5.1 and 6.0, +with capablities increasing with more modern servers. + +Versions 4.1 and higher support the basic SQL functionality that SQLAlchemy +uses in the ORM and SQL expressions. These versions pass the applicable +tests in the suite 100%. No heroic measures are taken to work around major +missing SQL features- if your server version does not support sub-selects, for +example, they won't work in SQLAlchemy either. + +Currently, the only DB-API driver supported is `MySQL-Python` (also referred to +as `MySQLdb`). Either 1.2.1 or 1.2.2 are recommended. The alpha, beta and +gamma releases of 1.2.1 and 1.2.2 should be avoided. Support for Jython and +IronPython is planned. + +===================================== =============== +Feature Minimum Version +===================================== =============== +sqlalchemy.orm 4.1.1 +Table Reflection 3.23.x +DDL Generation 4.1.1 +utf8/Full Unicode Connections 4.1.1 +Transactions 3.23.15 +Two-Phase Transactions 5.0.3 +Nested Transactions 5.0.3 +===================================== =============== + +See the official MySQL documentation for detailed information about features +supported in any given server release. + +Many MySQL server installations default to a ``latin1`` encoding for client +connections. All data sent through the connection will be converted +into ``latin1``, even if you have ``utf8`` or another character set on your +tables and columns. 
With versions 4.1 and higher, you can change the
+connection character set either through server configuration or by passing
+the ``charset`` parameter to ``create_engine``. The ``charset`` option is
+passed through to MySQL-Python and has the side-effect of also enabling
+``use_unicode`` in the driver by default. For regular encoded strings, also
+pass ``use_unicode=0`` in the connection arguments.
+
+Most MySQL server installations have a default table type of `MyISAM`, a
+non-transactional table type. During a transaction, non-transactional
+storage engines do not participate and continue to store table changes in
+autocommit mode. For fully atomic transactions, all participating tables
+must use a transactional engine such as `InnoDB`, `Falcon`, `SolidDB`,
+`PBXT`, etc. Storage engines can be elected when creating tables in
+SQLAlchemy by supplying a ``mysql_engine='whatever'`` to the ``Table``
+constructor. Any MySQL table creation option can be specified in this syntax.
+
+Not all MySQL storage engines support foreign keys. For `MyISAM` and similar
+engines, the information loaded by table reflection will not include foreign
+keys. For these tables, you may supply ``ForeignKeyConstraints`` at reflection
+time::
+
+  Table('mytable', metadata, autoload=True,
+        ForeignKeyConstraint(['other_id'], ['othertable.other_id']))
+
+When creating tables, SQLAlchemy will automatically set AUTO_INCREMENT on an
+integer primary key column::
+
+  >>> t = Table('mytable', metadata,
+  ...   Column('mytable_id', Integer, primary_key=True))
+  >>> t.create()
+  CREATE TABLE mytable (
+          mytable_id INTEGER NOT NULL AUTO_INCREMENT,
+          PRIMARY KEY (mytable_id)
+  )
+
+You can disable this behavior by supplying ``autoincrement=False`` in addition.
+This can also be used to enable auto-increment on a secondary column in a
+multi-column key for some storage engines::
+
+  Table('mytable', metadata,
+        Column('gid', Integer, primary_key=True, autoincrement=False),
+        Column('id', Integer, primary_key=True))
+
+MySQL SQL modes are supported. Modes that enable ``ANSI_QUOTES`` (such as
+``ANSI``) require an engine option to modify SQLAlchemy's quoting style.
+When using an ANSI-quoting mode, supply ``use_ansiquotes=True`` when
+creating your ``Engine``::
+
+  create_engine('mysql://localhost/test', use_ansiquotes=True)
+
+This is an engine-wide option and is not toggleable on a per-connection basis.
+SQLAlchemy does not presume to ``SET sql_mode`` for you with this option.
+For the best performance, set the quoting style server-wide in ``my.cnf`` or
+by supplying ``--sql-mode`` to ``mysqld``. You can also use a ``Pool`` hook
+to issue a ``SET SESSION sql_mode='...'`` on connect to configure each
+connection.
+
+If you do not specify ``use_ansiquotes``, the regular MySQL quoting style is
+used by default. Table reflection operations will query the server's sql mode
+and adjust the quoting style automatically if ``ANSI_QUOTES`` is in effect.
+
+If you do issue a ``SET sql_mode`` through SQLAlchemy, the dialect must be
+updated if the quoting style is changed. Again, this change will affect all
+connections::
+
+  connection.execute('SET sql_mode="ansi"')
+  connection.dialect.use_ansiquotes = True
+
+For normal SQLAlchemy usage, loading this module is unnecessary. It will be
+loaded on-demand when a MySQL connection is needed. The generic column types
+like ``String`` and ``Integer`` will automatically be adapted to the optimal
+matching MySQL column type.
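+
+For example, a table declared with only generic types (the table and column
+names here are purely illustrative) will have MySQL-appropriate DDL emitted
+for it::
+
+  users = Table('users', metadata,
+      Column('id', Integer, primary_key=True),
+      Column('name', String(50)))
+
+  users.create()   # renders VARCHAR(50) for the String column and an
+                   # AUTO_INCREMENT integer primary key, as shown above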
+ +But if you would like to use one of the MySQL-specific or enhanced column +types when creating tables with your ``Table`` definitions, then you will +need to import them from this module:: + + from sqlalchemy.databases import mysql + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('ittybittyblob', mysql.MSTinyBlob), + Column('biggy', mysql.MSBigInteger(unsigned=True))) + +All standard MySQL column types are supported. The OpenGIS types are +available for use via table reflection but have no special support or +mapping to Python classes. If you're using these types and have opinions +about how OpenGIS can be smartly integrated into SQLAlchemy please join +the mailing list! + +Many of the MySQL SQL extensions are handled through SQLAlchemy's generic +function and operator support:: + + table.select(table.c.password==func.md5('plaintext')) + table.select(table.c.username.op('regexp')('^[a-d]')) + +And of course any valid statement can be executed as a string rather than +through the SQL expression language. + +Some limited support for MySQL extensions to SQL expressions is currently +available. + + * SELECT pragma:: -from sqlalchemy import sql, schema, ansisql -from sqlalchemy.engine import default -import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions -import sqlalchemy.util as util + select(..., prefixes=['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) + + * UPDATE with LIMIT:: + + update(..., mysql_limit=10) + +If you have problems that seem server related, first check that you are +using the most recent stable MySQL-Python package available. The Database +Notes page on the wiki at http://www.sqlalchemy.org is a good resource for +timely information affecting MySQL in SQLAlchemy. +""" + +import datetime, inspect, re, sys from array import array as _array -from decimal import Decimal -try: - from threading import Lock -except ImportError: - from dummy_threading import Lock +from sqlalchemy import exceptions, logging, schema, sql, util +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy.sql import functions as sql_functions +from sqlalchemy.sql import compiler + +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import types as sqltypes + + +__all__ = ( + 'MSBigInteger', 'MSBinary', 'MSBit', 'MSBlob', 'MSBoolean', + 'MSChar', 'MSDate', 'MSDateTime', 'MSDecimal', 'MSDouble', + 'MSEnum', 'MSFloat', 'MSInteger', 'MSLongBlob', 'MSLongText', + 'MSMediumBlob', 'MSMediumText', 'MSNChar', 'MSNVarChar', + 'MSNumeric', 'MSSet', 'MSSmallInteger', 'MSString', 'MSText', + 'MSTime', 'MSTimeStamp', 'MSTinyBlob', 'MSTinyInteger', + 'MSTinyText', 'MSVarBinary', 'MSYear' ) + RESERVED_WORDS = util.Set( ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', @@ -60,10 +216,20 @@ RESERVED_WORDS = util.Set( 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', 'read_only', 'read_write', # 5.1 ]) -_per_connection_mutex = Lock() + +AUTOCOMMIT_RE = re.compile( + r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', + re.I | re.UNICODE) +SELECT_RE = re.compile( + r'\s*(?:SELECT|SHOW|DESCRIBE|XA RECOVER)', + re.I | re.UNICODE) +SET_RE = re.compile( + r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', + re.I | re.UNICODE) + class _NumericType(object): - "Base for MySQL numeric types." 
+ """Base for MySQL numeric types.""" def __init__(self, unsigned=False, zerofill=False, **kw): self.unsigned = unsigned @@ -71,21 +237,22 @@ class _NumericType(object): def _extend(self, spec): "Extend a numeric-type declaration with MySQL specific extensions." - + if self.unsigned: spec += ' UNSIGNED' if self.zerofill: spec += ' ZEROFILL' return spec + class _StringType(object): - "Base for MySQL string types." + """Base for MySQL string types.""" def __init__(self, charset=None, collation=None, ascii=False, unicode=False, binary=False, national=False, **kwargs): self.charset = charset - # allow collate= or collation= + # allow collate= or collation= self.collation = kwargs.get('collate', collation) self.ascii = ascii self.unicode = unicode @@ -96,7 +263,7 @@ class _StringType(object): """Extend a string-type declaration with standard SQL CHARACTER SET / COLLATE annotations and MySQL specific extensions. """ - + if self.charset: charset = 'CHARACTER SET %s' % self.charset elif self.ascii: @@ -112,7 +279,7 @@ class _StringType(object): collation = 'BINARY' else: collation = None - + if self.national: # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. return ' '.join([c for c in ('NATIONAL', spec, collation) @@ -123,7 +290,7 @@ class _StringType(object): def __repr__(self): attributes = inspect.getargspec(self.__init__)[0][1:] attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) - + params = {} for attr in attributes: val = getattr(self, attr) @@ -131,12 +298,13 @@ class _StringType(object): params[attr] = val return "%s(%s)" % (self.__class__.__name__, - ','.join(['%s=%s' % (k, params[k]) for k in params])) + ', '.join(['%s=%r' % (k, params[k]) for k in params])) + class MSNumeric(sqltypes.Numeric, _NumericType): - """MySQL NUMERIC type""" - - def __init__(self, precision = 10, length = 2, asdecimal=True, **kw): + """MySQL NUMERIC type.""" + + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a NUMERIC. precision @@ -157,24 +325,30 @@ class MSNumeric(sqltypes.Numeric, _NumericType): _NumericType.__init__(self, **kw) sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal) - + def get_col_spec(self): if self.precision is None: return self._extend("NUMERIC") else: return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if not self.asdecimal and isinstance(value, Decimal): - return float(value) + def result_processor(self, dialect): + if not self.asdecimal: + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process else: - return value + return None + class MSDecimal(MSNumeric): - """MySQL DECIMAL type""" + """MySQL DECIMAL type.""" def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DECIMAL. 
@@ -196,7 +370,7 @@ class MSDecimal(MSNumeric): """ super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw) - + def get_col_spec(self): if self.precision is None: return self._extend("DECIMAL") @@ -205,10 +379,11 @@ class MSDecimal(MSNumeric): else: return self._extend("DECIMAL(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) -class MSDouble(MSNumeric): - """MySQL DOUBLE type""" - def __init__(self, precision=10, length=2, asdecimal=True, **kw): +class MSDouble(sqltypes.Float, _NumericType): + """MySQL DOUBLE type.""" + + def __init__(self, precision=None, length=None, asdecimal=True, **kw): """Construct a DOUBLE. precision @@ -229,8 +404,14 @@ class MSDouble(MSNumeric): if ((precision is None and length is not None) or (precision is not None and length is None)): - raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.") - super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw) + raise exceptions.ArgumentError( + "You must specify both precision and length or omit " + "both altogether.") + + _NumericType.__init__(self, **kw) + sqltypes.Float.__init__(self, asdecimal=asdecimal) + self.length = length + self.precision = precision def get_col_spec(self): if self.precision is not None and self.length is not None: @@ -240,12 +421,45 @@ class MSDouble(MSNumeric): else: return self._extend('DOUBLE') + +class MSReal(MSDouble): + """MySQL REAL type.""" + + def __init__(self, precision=None, length=None, asdecimal=True, **kw): + """Construct a REAL. + + precision + Total digits in this number. If length and precision are both None, + values are stored to limits allowed by the server. + + length + The number of digits after the decimal point. + + unsigned + Optional. + + zerofill + Optional. If true, values will be stored as strings left-padded with + zeros. Note that this does not effect the values returned by the + underlying database API, which continue to be numeric. + """ + MSDouble.__init__(self, precision, length, asdecimal, **kw) + + def get_col_spec(self): + if self.precision is not None and self.length is not None: + return self._extend("REAL(%(precision)s, %(length)s)" % + {'precision': self.precision, + 'length' : self.length}) + else: + return self._extend('REAL') + + class MSFloat(sqltypes.Float, _NumericType): - """MySQL FLOAT type""" + """MySQL FLOAT type.""" - def __init__(self, precision=10, length=None, asdecimal=False, **kw): + def __init__(self, precision=None, length=None, asdecimal=False, **kw): """Construct a FLOAT. - + precision Total digits in this number. If length and precision are both None, values are stored to limits allowed by the server. @@ -262,25 +476,25 @@ class MSFloat(sqltypes.Float, _NumericType): underlying database API, which continue to be numeric. 
""" - if length is not None: - self.length=length _NumericType.__init__(self, **kw) - sqltypes.Float.__init__(self, precision, asdecimal=asdecimal) + sqltypes.Float.__init__(self, asdecimal=asdecimal) + self.length = length + self.precision = precision def get_col_spec(self): - if hasattr(self, 'length') and self.length is not None: - return self._extend("FLOAT(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) + if self.length is not None and self.precision is not None: + return self._extend("FLOAT(%s, %s)" % (self.precision, self.length)) elif self.precision is not None: - return self._extend("FLOAT(%(precision)s)" % {'precision': self.precision}) + return self._extend("FLOAT(%s)" % (self.precision,)) else: return self._extend("FLOAT") - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None class MSInteger(sqltypes.Integer, _NumericType): - """MySQL INTEGER type""" + """MySQL INTEGER type.""" def __init__(self, length=None, **kw): """Construct an INTEGER. @@ -307,8 +521,9 @@ class MSInteger(sqltypes.Integer, _NumericType): else: return self._extend("INTEGER") + class MSBigInteger(MSInteger): - """MySQL BIGINTEGER type""" + """MySQL BIGINTEGER type.""" def __init__(self, length=None, **kw): """Construct a BIGINTEGER. @@ -333,8 +548,40 @@ class MSBigInteger(MSInteger): else: return self._extend("BIGINT") -class MSSmallInteger(sqltypes.Smallinteger, _NumericType): - """MySQL SMALLINTEGER type""" + +class MSTinyInteger(MSInteger): + """MySQL TINYINT type.""" + + def __init__(self, length=None, **kw): + """Construct a TINYINT. + + Note: following the usual MySQL conventions, TINYINT(1) columns + reflected during Table(..., autoload=True) are treated as + Boolean columns. + + length + Optional, maximum display width for this number. + + unsigned + Optional. + + zerofill + Optional. If true, values will be stored as strings left-padded with + zeros. Note that this does not effect the values returned by the + underlying database API, which continue to be numeric. + """ + + super(MSTinyInteger, self).__init__(length, **kw) + + def get_col_spec(self): + if self.length is not None: + return self._extend("TINYINT(%s)" % self.length) + else: + return self._extend("TINYINT") + + +class MSSmallInteger(sqltypes.Smallinteger, MSInteger): + """MySQL SMALLINTEGER type.""" def __init__(self, length=None, **kw): """Construct a SMALLINTEGER. @@ -353,7 +600,7 @@ class MSSmallInteger(sqltypes.Smallinteger, _NumericType): self.length = length _NumericType.__init__(self, **kw) - sqltypes.Smallinteger.__init__(self, length) + sqltypes.SmallInteger.__init__(self, length) def get_col_spec(self): if self.length is not None: @@ -361,59 +608,104 @@ class MSSmallInteger(sqltypes.Smallinteger, _NumericType): else: return self._extend("SMALLINT") + +class MSBit(sqltypes.TypeEngine): + """MySQL BIT type. + + This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater for + MyISAM, MEMORY, InnoDB and BDB. For older versions, use a MSTinyInteger(1) + type. 
+ """ + + def __init__(self, length=None): + self.length = length + + def result_processor(self, dialect): + """Convert a MySQL's 64 bit, variable length binary string to a long.""" + def process(value): + if value is not None: + v = 0L + for i in map(ord, value): + v = v << 8 | i + value = v + return value + return process + + def get_col_spec(self): + if self.length is not None: + return "BIT(%s)" % self.length + else: + return "BIT" + + class MSDateTime(sqltypes.DateTime): - """MySQL DATETIME type""" + """MySQL DATETIME type.""" def get_col_spec(self): return "DATETIME" + class MSDate(sqltypes.Date): - """MySQL DATE type""" + """MySQL DATE type.""" def get_col_spec(self): return "DATE" + class MSTime(sqltypes.Time): - """MySQL TIME type""" + """MySQL TIME type.""" def get_col_spec(self): return "TIME" - def convert_result_value(self, value, dialect): - # convert from a timedelta value - if value is not None: - return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) - else: - return None + def result_processor(self, dialect): + def process(value): + # convert from a timedelta value + if value is not None: + return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) + else: + return None + return process class MSTimeStamp(sqltypes.TIMESTAMP): - """MySQL TIMESTAMP type + """MySQL TIMESTAMP type. To signal the orm to automatically re-select modified rows to retrieve - the timestamp, add a PassiveDefault to your column specification: + the updated timestamp, add a PassiveDefault to your column specification:: from sqlalchemy.databases import mysql - Column('updated', mysql.MSTimeStamp, PassiveDefault(text('CURRENT_TIMESTAMP()'))) + Column('updated', mysql.MSTimeStamp, + PassiveDefault(sql.text('CURRENT_TIMESTAMP'))) + + The full range of MySQL 4.1+ TIMESTAMP defaults can be specified in + the PassiveDefault:: + + PassiveDefault(sql.text('CURRENT TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')) + """ def get_col_spec(self): return "TIMESTAMP" -class MSYear(sqltypes.String): - """MySQL YEAR type, for single byte storage of years 1901-2155""" + +class MSYear(sqltypes.TypeEngine): + """MySQL YEAR type, for single byte storage of years 1901-2155.""" + + def __init__(self, length=None): + self.length = length def get_col_spec(self): if self.length is None: return "YEAR" else: - return "YEAR(%d)" % self.length + return "YEAR(%s)" % self.length + +class MSText(_StringType, sqltypes.Text): + """MySQL TEXT type, for text up to 2^16 characters.""" -class MSText(_StringType, sqltypes.TEXT): - """MySQL TEXT type, for text up to 2^16 characters""" - def __init__(self, length=None, **kwargs): """Construct a TEXT. - + length Optional, if provided the server may optimize storage by subsitituting the smallest TEXT type sufficient to store @@ -447,22 +739,22 @@ class MSText(_StringType, sqltypes.TEXT): """ _StringType.__init__(self, **kwargs) - sqltypes.TEXT.__init__(self, length, - kwargs.get('convert_unicode', False)) + sqltypes.Text.__init__(self, length, + kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None)) def get_col_spec(self): if self.length: return self._extend("TEXT(%d)" % self.length) else: return self._extend("TEXT") - + class MSTinyText(MSText): - """MySQL TINYTEXT type, for text up to 2^8 characters""" + """MySQL TINYTEXT type, for text up to 2^8 characters.""" def __init__(self, **kwargs): """Construct a TINYTEXT. - + charset Optional, a column-level character set for this string value. 
Takes precendence to 'ascii' or 'unicode' short-hand. @@ -495,12 +787,13 @@ class MSTinyText(MSText): def get_col_spec(self): return self._extend("TINYTEXT") + class MSMediumText(MSText): - """MySQL MEDIUMTEXT type, for text up to 2^24 characters""" + """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" def __init__(self, **kwargs): """Construct a MEDIUMTEXT. - + charset Optional, a column-level character set for this string value. Takes precendence to 'ascii' or 'unicode' short-hand. @@ -533,12 +826,13 @@ class MSMediumText(MSText): def get_col_spec(self): return self._extend("MEDIUMTEXT") + class MSLongText(MSText): - """MySQL LONGTEXT type, for text up to 2^32 characters""" + """MySQL LONGTEXT type, for text up to 2^32 characters.""" def __init__(self, **kwargs): """Construct a LONGTEXT. - + charset Optional, a column-level character set for this string value. Takes precendence to 'ascii' or 'unicode' short-hand. @@ -571,12 +865,13 @@ class MSLongText(MSText): def get_col_spec(self): return self._extend("LONGTEXT") + class MSString(_StringType, sqltypes.String): """MySQL VARCHAR type, for variable-length character data.""" def __init__(self, length=None, **kwargs): """Construct a VARCHAR. - + length Maximum data length, in characters. @@ -609,7 +904,7 @@ class MSString(_StringType, sqltypes.String): _StringType.__init__(self, **kwargs) sqltypes.String.__init__(self, length, - kwargs.get('convert_unicode', False)) + kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None)) def get_col_spec(self): if self.length: @@ -617,12 +912,13 @@ class MSString(_StringType, sqltypes.String): else: return self._extend("TEXT") + class MSChar(_StringType, sqltypes.CHAR): """MySQL CHAR type, for fixed-length character data.""" - + def __init__(self, length, **kwargs): """Construct an NCHAR. - + length Maximum data length, in characters. @@ -642,14 +938,17 @@ class MSChar(_StringType, sqltypes.CHAR): def get_col_spec(self): return self._extend("CHAR(%(length)s)" % {'length' : self.length}) + class MSNVarChar(_StringType, sqltypes.String): - """MySQL NVARCHAR type, for variable-length character data in the - server's configured national character set. + """MySQL NVARCHAR type. + + For variable-length character data in the server's configured national + character set. """ def __init__(self, length=None, **kwargs): """Construct an NVARCHAR. - + length Maximum data length, in characters. @@ -672,10 +971,13 @@ class MSNVarChar(_StringType, sqltypes.String): # We'll actually generate the equiv. "NATIONAL VARCHAR" instead # of "NVARCHAR". return self._extend("VARCHAR(%(length)s)" % {'length': self.length}) - + + class MSNChar(_StringType, sqltypes.CHAR): - """MySQL NCHAR type, for fixed-length character data in the - server's configured national character set. + """MySQL NCHAR type. + + For fixed-length character data in the server's configured national + character set. """ def __init__(self, length=None, **kwargs): @@ -701,8 +1003,9 @@ class MSNChar(_StringType, sqltypes.CHAR): # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". 
return self._extend("CHAR(%(length)s)" % {'length': self.length}) + class _BinaryType(sqltypes.Binary): - """MySQL binary types""" + """Base for MySQL binary types.""" def get_col_spec(self): if self.length: @@ -710,14 +1013,16 @@ class _BinaryType(sqltypes.Binary): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process class MSVarBinary(_BinaryType): - """MySQL VARBINARY type, for variable length binary data""" + """MySQL VARBINARY type, for variable length binary data.""" def __init__(self, length=None, **kw): """Construct a VARBINARY. Arguments are: @@ -733,6 +1038,7 @@ class MSVarBinary(_BinaryType): else: return "BLOB" + class MSBinary(_BinaryType): """MySQL BINARY type, for fixed length binary data""" @@ -742,7 +1048,7 @@ class MSBinary(_BinaryType): pad value. length - Maximum data length, in bytes. If not length is specified, this + Maximum data length, in bytes. If length is not specified, this will generate a BLOB. This usage is deprecated. """ @@ -754,15 +1060,16 @@ class MSBinary(_BinaryType): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process class MSBlob(_BinaryType): - """MySQL BLOB type, for binary data up to 2^16 bytes""" - + """MySQL BLOB type, for binary data up to 2^16 bytes""" def __init__(self, length=None, **kw): """Construct a BLOB. Arguments are: @@ -781,49 +1088,57 @@ class MSBlob(_BinaryType): else: return "BLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return buffer(value) + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return buffer(value) + return process def __repr__(self): return "%s()" % self.__class__.__name__ + class MSTinyBlob(MSBlob): - """MySQL TINYBLOB type, for binary data up to 2^8 bytes""" + """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" def get_col_spec(self): return "TINYBLOB" -class MSMediumBlob(MSBlob): - """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes""" + +class MSMediumBlob(MSBlob): + """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" def get_col_spec(self): return "MEDIUMBLOB" + class MSLongBlob(MSBlob): - """MySQL LONGBLOB type, for binary data up to 2^32 bytes""" + """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" def get_col_spec(self): return "LONGBLOB" + class MSEnum(MSString): """MySQL ENUM type.""" - + def __init__(self, *enums, **kw): - """ - Construct an ENUM. + """Construct an ENUM. Example: Column('myenum', MSEnum("'foo'", "'bar'", "'baz'")) Arguments are: - + enums The range of valid values for this ENUM. Values will be used - exactly as they appear when generating schemas + exactly as they appear when generating schemas. Strings must + be quoted, as in the example above. Single-quotes are suggested + for ANSI compatability and are required for portability to servers + with ANSI_QUOTES enabled. strict Defaults to False: ensure that a given value is in this ENUM's @@ -854,102 +1169,216 @@ class MSEnum(MSString): schema. This does not affect the type of data stored, only the collation of character data. 
""" - + self.__ddl_values = enums strip_enums = [] for a in enums: if a[0:1] == '"' or a[0:1] == "'": - a = a[1:-1] + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) strip_enums.append(a) - + self.enums = strip_enums self.strict = kw.pop('strict', False) - length = max([len(v) for v in strip_enums]) + length = max([len(v) for v in strip_enums] + [0]) super(MSEnum, self).__init__(length, **kw) - def convert_bind_param(self, value, engine): - if self.strict and value is not None and value not in self.enums: - raise exceptions.InvalidRequestError('"%s" not a valid value for ' - 'this enum' % value) - return super(MSEnum, self).convert_bind_param(value, engine) + def bind_processor(self, dialect): + super_convert = super(MSEnum, self).bind_processor(dialect) + def process(value): + if self.strict and value is not None and value not in self.enums: + raise exceptions.InvalidRequestError('"%s" not a valid value for ' + 'this enum' % value) + if super_convert: + return super_convert(value) + else: + return value + return process def get_col_spec(self): return self._extend("ENUM(%s)" % ",".join(self.__ddl_values)) + +class MSSet(MSString): + """MySQL SET type.""" + + def __init__(self, *values, **kw): + """Construct a SET. + + Example:: + + Column('myset', MSSet("'foo'", "'bar'", "'baz'")) + + Arguments are: + + values + The range of valid values for this SET. Values will be used + exactly as they appear when generating schemas. Strings must + be quoted, as in the example above. Single-quotes are suggested + for ANSI compatability and are required for portability to servers + with ANSI_QUOTES enabled. + + charset + Optional, a column-level character set for this string + value. Takes precendence to 'ascii' or 'unicode' short-hand. + + collation + Optional, a column-level collation for this string value. + Takes precedence to 'binary' short-hand. + + ascii + Defaults to False: short-hand for the ``latin1`` character set, + generates ASCII in schema. + + unicode + Defaults to False: short-hand for the ``ucs2`` character set, + generates UNICODE in schema. + + binary + Defaults to False: short-hand, pick the binary collation type + that matches the column's character set. Generates BINARY in + schema. This does not affect the type of data stored, only the + collation of character data. + """ + + self.__ddl_values = values + + strip_values = [] + for a in values: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_values.append(a) + + self.values = strip_values + length = max([len(v) for v in strip_values] + [0]) + super(MSSet, self).__init__(length, **kw) + + def result_processor(self, dialect): + def process(value): + # The good news: + # No ',' quoting issues- commas aren't allowed in SET values + # The bad news: + # Plenty of driver inconsistencies here. 
+ if isinstance(value, util.set_types): + # ..some versions convert '' to an empty set + if not value: + value.add('') + # ..some return sets.Set, even for pythons that have __builtin__.set + if not isinstance(value, util.Set): + value = util.Set(value) + return value + # ...and some versions return strings + if value is not None: + return util.Set(value.split(',')) + else: + return value + return process + + def bind_processor(self, dialect): + super_convert = super(MSSet, self).bind_processor(dialect) + def process(value): + if value is None or isinstance(value, (int, long, basestring)): + pass + else: + if None in value: + value = util.Set(value) + value.remove(None) + value.add('') + value = ','.join(value) + if super_convert: + return super_convert(value) + else: + return value + return process + + def get_col_spec(self): + return self._extend("SET(%s)" % ",".join(self.__ddl_values)) + + class MSBoolean(sqltypes.Boolean): + """MySQL BOOLEAN type.""" + def get_col_spec(self): return "BOOL" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False - -# TODO: SET, BIT + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process colspecs = { - sqltypes.Integer : MSInteger, - sqltypes.Smallinteger : MSSmallInteger, - sqltypes.Numeric : MSNumeric, - sqltypes.Float : MSFloat, - sqltypes.DateTime : MSDateTime, - sqltypes.Date : MSDate, - sqltypes.Time : MSTime, - sqltypes.String : MSString, - sqltypes.Binary : MSBlob, - sqltypes.Boolean : MSBoolean, - sqltypes.TEXT : MSText, + sqltypes.Integer: MSInteger, + sqltypes.Smallinteger: MSSmallInteger, + sqltypes.Numeric: MSNumeric, + sqltypes.Float: MSFloat, + sqltypes.DateTime: MSDateTime, + sqltypes.Date: MSDate, + sqltypes.Time: MSTime, + sqltypes.String: MSString, + sqltypes.Binary: MSBlob, + sqltypes.Boolean: MSBoolean, + sqltypes.Text: MSText, sqltypes.CHAR: MSChar, sqltypes.NCHAR: MSNChar, sqltypes.TIMESTAMP: MSTimeStamp, sqltypes.BLOB: MSBlob, + MSDouble: MSDouble, + MSReal: MSReal, _BinaryType: _BinaryType, } - +# Everything 3.23 through 5.1 excepting OpenGIS types. 
ischema_names = { - 'bigint' : MSBigInteger, - 'binary' : MSBinary, - 'blob' : MSBlob, + 'bigint': MSBigInteger, + 'binary': MSBinary, + 'bit': MSBit, + 'blob': MSBlob, 'boolean':MSBoolean, - 'char' : MSChar, - 'date' : MSDate, - 'datetime' : MSDateTime, - 'decimal' : MSDecimal, - 'double' : MSDouble, + 'char': MSChar, + 'date': MSDate, + 'datetime': MSDateTime, + 'decimal': MSDecimal, + 'double': MSDouble, 'enum': MSEnum, 'fixed': MSDecimal, - 'float' : MSFloat, - 'int' : MSInteger, - 'integer' : MSInteger, + 'float': MSFloat, + 'int': MSInteger, + 'integer': MSInteger, 'longblob': MSLongBlob, 'longtext': MSLongText, 'mediumblob': MSMediumBlob, - 'mediumint' : MSInteger, + 'mediumint': MSInteger, 'mediumtext': MSMediumText, 'nchar': MSNChar, 'nvarchar': MSNVarChar, - 'numeric' : MSNumeric, - 'smallint' : MSSmallInteger, - 'text' : MSText, - 'time' : MSTime, - 'timestamp' : MSTimeStamp, + 'numeric': MSNumeric, + 'set': MSSet, + 'smallint': MSSmallInteger, + 'text': MSText, + 'time': MSTime, + 'timestamp': MSTimeStamp, 'tinyblob': MSTinyBlob, - 'tinyint' : MSSmallInteger, - 'tinytext' : MSTinyText, - 'varbinary' : MSVarBinary, - 'varchar' : MSString, + 'tinyint': MSTinyInteger, + 'tinytext': MSTinyText, + 'varbinary': MSVarBinary, + 'varchar': MSString, + 'year': MSYear, } def descriptor(): @@ -962,49 +1391,74 @@ def descriptor(): ('host',"Hostname", None), ]} + class MySQLExecutionContext(default.DefaultExecutionContext): def post_exec(self): - if self.compiled.isinsert: - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - - def is_select(self): - return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None - -class MySQLDialect(ansisql.ANSIDialect): - def __init__(self, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) - self.per_connection = weakref.WeakKeyDictionary() + if self.compiled.isinsert and not self.executemany: + if (not len(self._last_inserted_ids) or + self._last_inserted_ids[0] is None): + self._last_inserted_ids = ([self.cursor.lastrowid] + + self._last_inserted_ids[1:]) + elif (not self.isupdate and not self.should_autocommit and + self.statement and SET_RE.match(self.statement)): + # This misses if a user forces autocommit on text('SET NAMES'), + # which is probably a programming error anyhow. + self.connection.info.pop(('mysql', 'charset'), None) + + def returns_rows_text(self, statement): + return SELECT_RE.match(statement) + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_RE.match(statement) + + +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. Not used directly in application code.""" + + supports_alter = True + supports_unicode_statements = False + # identifiers are 64, however aliases can be 255... 
+ max_identifier_length = 255 + supports_sane_rowcount = True + default_paramstyle = 'format' + + def __init__(self, use_ansiquotes=None, **kwargs): + self.use_ansiquotes = use_ansiquotes + default.DefaultDialect.__init__(self, **kwargs) def dbapi(cls): import MySQLdb as mysql return mysql dbapi = classmethod(dbapi) - + def create_connect_args(self, url): - opts = url.translate_connect_args(['host', 'db', 'user', 'passwd', 'port']) + opts = url.translate_connect_args(database='db', username='user', + password='passwd') opts.update(url.query) util.coerce_kw_type(opts, 'compress', bool) util.coerce_kw_type(opts, 'connect_timeout', int) util.coerce_kw_type(opts, 'client_flag', int) util.coerce_kw_type(opts, 'local_infile', int) - # note: these two could break SA Unicode type - util.coerce_kw_type(opts, 'use_unicode', bool) + # Note: using either of the below will cause all strings to be returned + # as Unicode, both in raw SQL operations and with column types like + # String and MSString. + util.coerce_kw_type(opts, 'use_unicode', bool) util.coerce_kw_type(opts, 'charset', str) - # TODO: cursorclass and conv: support via query string or punt? - - # ssl + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. + ssl = {} for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: if key in opts: ssl[key[4:]] = opts[key] util.coerce_kw_type(ssl, key[4:], str) del opts[key] - if len(ssl): + if ssl: opts['ssl'] = ssl - - # FOUND_ROWS must be set in CLIENT_FLAGS for to enable + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. client_flag = opts.get('client_flag', 0) if self.dbapi is not None: @@ -1016,221 +1470,237 @@ class MySQLDialect(ansisql.ANSIDialect): opts['client_flag'] = client_flag return [[], opts] - def create_execution_context(self, *args, **kwargs): - return MySQLExecutionContext(self, *args, **kwargs) + def create_execution_context(self, connection, **kwargs): + return MySQLExecutionContext(self, connection, **kwargs) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - # identifiers are 64, however aliases can be 255... - def max_identifier_length(self): - return 255; - - def supports_sane_rowcount(self): - return True - - def compiler(self, statement, bindparams, **kwargs): - return MySQLCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return MySQLSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return MySQLSchemaDropper(self, *args, **kwargs) - - def preparer(self): - return MySQLIdentifierPreparer(self) - - def do_executemany(self, cursor, statement, parameters, context=None, **kwargs): + def do_executemany(self, cursor, statement, parameters, context=None): rowcount = cursor.executemany(statement, parameters) if context is not None: context._rowcount = rowcount - + def supports_unicode_statements(self): return True - - def do_execute(self, cursor, statement, parameters, **kwargs): + + def do_execute(self, cursor, statement, parameters, context=None): cursor.execute(statement, parameters) + def do_commit(self, connection): + """Execute a COMMIT.""" + + # COMMIT/ROLLBACK were introduced in 3.23.15. + # Yes, we have at least one user who has to talk to these old versions! + # + # Ignore commit/rollback if support isn't present, otherwise even basic + # operations via autocommit fail. 
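+        # The commit is attempted unconditionally; an error is swallowed only
+        # when the server predates 3.23.15 and reports code 1064 (a syntax
+        # error for the unrecognized COMMIT). Anything else is re-raised.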
+ try: + connection.commit() + except: + if self._server_version_info(connection) < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + def do_rollback(self, connection): - # MySQL without InnoDB doesnt support rollback() + """Execute a ROLLBACK.""" + try: connection.rollback() except: - pass + if self._server_version_info(connection) < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise def do_begin_twophase(self, connection, xid): - connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute("XA BEGIN %s", xid) def do_prepare_twophase(self, connection, xid): - connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) - connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute("XA END %s", xid) + connection.execute("XA PREPARE %s", xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + def do_rollback_twophase(self, connection, xid, is_prepared=True, + recover=False): if not is_prepared: - connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) - connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute("XA END %s", xid) + connection.execute("XA ROLLBACK %s", xid) - def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + def do_commit_twophase(self, connection, xid, is_prepared=True, + recover=False): if not is_prepared: self.do_prepare_twophase(connection, xid) - connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)])) - + connection.execute("XA COMMIT %s", xid) + def do_recover_twophase(self, connection): - resultset = connection.execute(sql.text("XA RECOVER")) + resultset = connection.execute("XA RECOVER") return [row['data'][0:row['gtrid_length']] for row in resultset] + def do_ping(self, connection): + connection.ping() + def is_disconnect(self, e): - return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055) + if isinstance(e, self.dbapi.OperationalError): + return e.args[0] in (2006, 2013, 2014, 2045, 2055) + elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, this is the error you get + return "(0, '')" in str(e) + else: + return False def get_default_schema_name(self, connection): - try: - return self._default_schema_name - except AttributeError: - name = self._default_schema_name = \ - connection.execute('SELECT DATABASE()').scalar() - return name + return connection.execute('SELECT DATABASE()').scalar() + get_default_schema_name = engine_base.connection_memoize( + ('dialect', 'default_schema_name'))(get_default_schema_name) + + def table_names(self, connection, schema): + """Return a Unicode SHOW TABLES from a given schema.""" + + charset = self._detect_charset(connection) + self._autoset_identifier_style(connection) + rp = connection.execute("SHOW TABLES FROM %s" % + self.identifier_preparer.quote_identifier(schema)) + return [row[0] for row in _compat_fetchall(rp, charset=charset)] def has_table(self, connection, table_name, schema=None): # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly # on macosx (and maybe win?) with multibyte table names. # # TODO: if this is not a problem on win, make the strategy swappable - # based on platform. DESCRIBE is much slower. 
- if schema is not None: - st = "DESCRIBE `%s`.`%s`" % (schema, table_name) - else: - st = "DESCRIBE `%s`" % table_name + # based on platform. DESCRIBE is slower. + + # [ticket:726] + # full_name = self.identifier_preparer.format_table(table, + # use_schema=True) + + self._autoset_identifier_style(connection) + + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + + st = "DESCRIBE %s" % full_name + rs = None try: - rs = connection.execute(st) - have = rs.rowcount > 0 - rs.close() - return have - except exceptions.SQLError, e: - if e.orig.args[0] == 1146: - return False - raise + try: + rs = connection.execute(st) + have = rs.rowcount > 0 + rs.close() + return have + except exceptions.SQLError, e: + if e.orig.args[0] == 1146: + return False + raise + finally: + if rs: + rs.close() + + def server_version_info(self, connection): + """A tuple of the database server version. + + Formats the remote server version as a tuple of version values, + e.g. ``(5, 0, 44)``. If there are strings in the version number + they will be in the tuple too, so don't count on these all being + ``int`` values. + + This is a fast check that does not require a round trip. It is also + cached per-Connection. + """ + + return self._server_version_info(connection.connection.connection) + server_version_info = engine_base.connection_memoize( + ('mysql', 'server_version_info'))(server_version_info) + + def _server_version_info(self, dbapi_con): + """Convert a MySQL-python server_info string into a tuple.""" - def get_version_info(self, connectable): - if hasattr(connectable, 'connect'): - con = connectable.connect().connection - else: - con = connectable version = [] - for n in con.get_server_info().split('.'): + r = re.compile('[.\-]') + for n in r.split(dbapi_con.get_server_info()): try: version.append(int(n)) except ValueError: version.append(n) return tuple(version) + # @deprecated + def get_version_info(self, connectable): + """A tuple of the database server version. + + Deprecated, use ``server_version_info()``. 
+ """ + + if isinstance(connectable, engine_base.Engine): + connectable = connectable.contextual_connect() + + return self.server_version_info(connectable) + get_version_info = util.deprecated()(get_version_info) + def reflecttable(self, connection, table, include_columns): """Load column definitions from the server.""" - decode_from = self._detect_charset(connection) - case_sensitive = self._detect_case_sensitive(connection, decode_from) - - if not case_sensitive: - table.name = table.name.lower() - table.metadata.tables[table.name]= table + charset = self._detect_charset(connection) + self._autoset_identifier_style(connection) try: - rp = connection.execute("DESCRIBE " + self._escape_table_name(table)) - except exceptions.SQLError, e: - if e.orig.args[0] == 1146: - raise exceptions.NoSuchTableError(table.fullname) - raise + reflector = self.reflector + except AttributeError: + preparer = self.identifier_preparer + if (self.server_version_info(connection) < (4, 1) and + self.use_ansiquotes): + # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 + preparer = MySQLIdentifierPreparer(self) - for row in _compat_fetchall(rp, charset=decode_from): - (name, type, nullable, primary_key, default) = \ - (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4]) + self.reflector = reflector = MySQLSchemaReflector(preparer) - # leave column names as unicode - name = name.decode(decode_from) - - if include_columns and name not in include_columns: - continue + sql = self._show_create_table(connection, table, charset) + if sql.startswith('CREATE ALGORITHM'): + # Adapt views to something table-like. + columns = self._describe_table(connection, table, charset) + sql = reflector._describe_to_create(table, columns) - match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type) - col_type = match.group(1) - args = match.group(2) - extra_1 = match.group(3) - extra_2 = match.group(4) + self._adjust_casing(connection, table) - try: - coltype = ischema_names[col_type] - except KeyError: - warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (col_type, name))) - coltype = sqltypes.NULLTYPE - - kw = {} - if extra_1 is not None: - kw[extra_1] = True - if extra_2 is not None: - kw[extra_2] = True - - if args is not None: - if col_type == 'enum': - args= args[1:-1] - argslist = args.split(',') - coltype = coltype(*argslist, **kw) - else: - argslist = re.findall(r'(\d+)', args) - coltype = coltype(*[int(a) for a in argslist], **kw) + return reflector.reflect(connection, table, sql, charset, + only=include_columns) - colargs= [] - if default: - if col_type == 'timestamp' and default == 'CURRENT_TIMESTAMP': - default = sql.text(default) - colargs.append(schema.PassiveDefault(default)) - table.append_column(schema.Column(name, coltype, *colargs, - **dict(primary_key=primary_key, - nullable=nullable, - ))) + def _adjust_casing(self, connection, table, charset=None): + """Adjust Table name to the server case sensitivity, if needed.""" - tabletype = self.moretableinfo(connection, table, decode_from) - table.kwargs['mysql_engine'] = tabletype + casing = self._detect_casing(connection) - def moretableinfo(self, connection, table, charset=None): - """SHOW CREATE TABLE to get foreign key/table options.""" + # For winxx database hosts. TODO: is this really needed? 
+ if casing == 1 and table.name != table.name.lower(): + table.name = table.name.lower() + lc_alias = schema._get_table_key(table.name, table.schema) + table.metadata.tables[lc_alias] = table - rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {}) - row = _compat_fetchone(rp, charset=charset) - if not row: - raise exceptions.NoSuchTableError(table.fullname) - desc = row[1].strip() - - tabletype = '' - lastparen = re.search(r'\)[^\)]*\Z', desc) - if lastparen: - match = re.search(r'\b(?:TYPE|ENGINE)=(?P.+)\b', desc[lastparen.start():], re.I) - if match: - tabletype = match.group('ttype') - - # \x27 == ' (single quote) (avoid xemacs syntax highlighting issue) - fkpat = r'''CONSTRAINT [`"\x27](?P.+?)[`"\x27] FOREIGN KEY \((?P.+?)\) REFERENCES [`"\x27](?P.+?)[`"\x27] \((?P.+?)\)''' - for match in re.finditer(fkpat, desc): - columns = re.findall(r'''[`"\x27](.+?)[`"\x27]''', match.group('columns')) - refcols = [match.group('reftable') + "." + x for x in re.findall(r'''[`"\x27](.+?)[`"\x27]''', match.group('refcols'))] - schema.Table(match.group('reftable'), table.metadata, autoload=True, autoload_with=connection) - constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name')) - table.append_constraint(constraint) - - return tabletype - - def _escape_table_name(self, table): - if table.schema is not None: - return '`%s`.`%s`' % (table.schema, table.name) - else: - return '`%s`' % table.name def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" + # Allow user override, won't sniff if force_charset is set. + if ('mysql', 'force_charset') in connection.info: + return connection.info[('mysql', 'force_charset')] + # Note: MySQL-python 1.2.1c7 seems to ignore changes made # on a connection via set_character_set() - - rs = connection.execute("show variables like 'character_set%%'") + if self.server_version_info(connection) < (4, 1, 0): + try: + return connection.connection.character_set_name() + except AttributeError: + # < 1.2.1 final MySQL-python drivers have no charset support. + # a query is needed. + pass + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)]) if 'character_set_results' in opts: @@ -1238,52 +1708,170 @@ class MySQLDialect(ansisql.ANSIDialect): try: return connection.connection.character_set_name() except AttributeError: - # < 1.2.1 final MySQL-python drivers have no charset support + # Still no charset on < 1.2.1 final... if 'character_set' in opts: return opts['character_set'] else: - warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python. MySQL-python >= 1.2.2 is recommended. Assuming latin1.")) + util.warn( + "Could not detect the connection character set with this " + "combination of MySQL server and MySQL-python. " + "MySQL-python >= 1.2.2 is recommended. 
Assuming latin1.") return 'latin1' + _detect_charset = engine_base.connection_memoize( + ('mysql', 'charset'))(_detect_charset) - def _detect_case_sensitive(self, connection, charset=None): + + def _detect_casing(self, connection): """Sniff out identifier case sensitivity. Cached per-connection. This value can not change without a server restart. + """ # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html - _per_connection_mutex.acquire() + charset = self._detect_charset(connection) + row = _compat_fetchone(connection.execute( + "SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset) + if not row: + cs = 0 + else: + # 4.0.15 returns OFF or ON according to [ticket:489] + # 3.23 doesn't, 4.0.27 doesn't.. + if row[1] == 'OFF': + cs = 0 + elif row[1] == 'ON': + cs = 1 + else: + cs = int(row[1]) + row.close() + return cs + _detect_casing = engine_base.connection_memoize( + ('mysql', 'lower_case_table_names'))(_detect_casing) + + def _detect_collations(self, connection): + """Pull the active COLLATIONS list from the server. + + Cached per-connection. + """ + + collations = {} + if self.server_version_info(connection) < (4, 1, 0): + pass + else: + charset = self._detect_charset(connection) + rs = connection.execute('SHOW COLLATION') + for row in _compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + _detect_collations = engine_base.connection_memoize( + ('mysql', 'collations'))(_detect_collations) + + def use_ansiquotes(self, useansi): + self._use_ansiquotes = useansi + if useansi: + self.preparer = MySQLANSIIdentifierPreparer + else: + self.preparer = MySQLIdentifierPreparer + # icky + if hasattr(self, 'identifier_preparer'): + self.identifier_preparer = self.preparer(self) + if hasattr(self, 'reflector'): + del self.reflector + + use_ansiquotes = property(lambda s: s._use_ansiquotes, use_ansiquotes, + doc="True if ANSI_QUOTES is in effect.") + + def _autoset_identifier_style(self, connection, charset=None): + """Detect and adjust for the ANSI_QUOTES sql mode. + + If the dialect's use_ansiquotes is unset, query the server's sql mode + and reset the identifier style. + + Note that this currently *only* runs during reflection. Ideally this + would run the first time a connection pool connects to the database, + but the infrastructure for that is not yet in place. 
+ """ + + if self.use_ansiquotes is not None: + return + + row = _compat_fetchone( + connection.execute("SHOW VARIABLES LIKE 'sql_mode'", + charset=charset)) + if not row: + mode = '' + else: + mode = row[1] or '' + # 4.0 + if mode.isdigit(): + mode_no = int(mode) + mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' + + self.use_ansiquotes = 'ANSI_QUOTES' in mode + + def _show_create_table(self, connection, table, charset=None, + full_name=None): + """Run SHOW CREATE TABLE for a ``Table``.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "SHOW CREATE TABLE %s" % full_name + + rp = None try: - raw_connection = connection.connection.connection - cache = self.per_connection.get(raw_connection, {}) - if 'lower_case_table_names' not in cache: - row = _compat_fetchone(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) - if not row: - cs = True + try: + rp = connection.execute(st) + except exceptions.SQLError, e: + if e.orig.args[0] == 1146: + raise exceptions.NoSuchTableError(full_name) else: - cs = row[1] in ('0', 'OFF' 'off') - cache['lower_case_table_names'] = cs - self.per_connection[raw_connection] = cache - return cache.get('lower_case_table_names') + raise + row = _compat_fetchone(rp, charset=charset) + if not row: + raise exceptions.NoSuchTableError(full_name) + return row[1].strip() finally: - _per_connection_mutex.release() + if rp: + rp.close() -def _compat_fetchall(rp, charset=None): - """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" + return sql - return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] + def _describe_table(self, connection, table, charset=None, + full_name=None): + """Run DESCRIBE for a ``Table`` and return processed rows.""" -def _compat_fetchone(rp, charset=None): - """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "DESCRIBE %s" % full_name - return _MySQLPythonRowProxy(rp.fetchone(), charset) - + rp, rows = None, None + try: + try: + rp = connection.execute(st) + except exceptions.SQLError, e: + if e.orig.args[0] == 1146: + raise exceptions.NoSuchTableError(full_name) + else: + raise + rows = _compat_fetchall(rp, charset=charset) + finally: + if rp: + rp.close() + return rows class _MySQLPythonRowProxy(object): - """Return consistent column values for all versions of MySQL-python (esp. alphas) and unicode settings.""" + """Return consistent column values for all versions of MySQL-python. + + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. + """ + + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. 
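The sql_mode handling above can be summarized as a standalone check. A small sketch, assuming the raw value of the sql_mode variable is already in hand; the bit value 4 for ANSI_QUOTES on 4.0-style numeric modes is taken from the code above:

    def ansiquotes_enabled(sql_mode):
        # MySQL 4.0 reports sql_mode as a numeric bitmask (bit 4 == ANSI_QUOTES);
        # later servers report a comma-separated list of mode names.
        mode = sql_mode or ''
        if mode.isdigit():
            mode_no = int(mode)
            mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or ''
        return 'ANSI_QUOTES' in mode

    assert ansiquotes_enabled('ANSI_QUOTES,STRICT_TRANS_TABLES')
    assert ansiquotes_enabled('4')          # 4.0-style numeric mode
    assert not ansiquotes_enabled('3')
    assert not ansiquotes_enabled('')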
def __init__(self, rowproxy, charset): self.rowproxy = rowproxy @@ -1292,37 +1880,90 @@ class _MySQLPythonRowProxy(object): item = self.rowproxy[index] if isinstance(item, _array): item = item.tostring() - if self.charset and isinstance(item, unicode): - return item.encode(self.charset) + if self.charset and isinstance(item, str): + return item.decode(self.charset) else: return item def __getattr__(self, attr): item = getattr(self.rowproxy, attr) if isinstance(item, _array): item = item.tostring() - if self.charset and isinstance(item, unicode): - return item.encode(self.charset) + if self.charset and isinstance(item, str): + return item.decode(self.charset) else: return item -class MySQLCompiler(ansisql.ANSICompiler): - operators = ansisql.ANSICompiler.operators.copy() - operators.update( - { - sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y), - operator.mod : '%%' - } - ) +class MySQLCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() + operators.update({ + sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), + sql_operators.mod: '%%' + }) + functions = compiler.DefaultCompiler.functions.copy() + functions.update ({ + sql_functions.random: 'rand%(expr)s' + }) - def visit_cast(self, cast, **kwargs): - if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)): - return super(MySQLCompiler, self).visit_cast(cast, **kwargs) + + def visit_typeclause(self, typeclause): + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, MSInteger): + if getattr(type_, 'unsigned', False): + return 'UNSIGNED INTEGER' + else: + return 'SIGNED INTEGER' + elif isinstance(type_, (MSDecimal, MSDateTime, MSDate, MSTime)): + return type_.get_col_spec() + elif isinstance(type_, MSText): + return 'CHAR' + elif (isinstance(type_, _StringType) and not + isinstance(type_, (MSEnum, MSSet))): + if getattr(type_, 'length'): + return 'CHAR(%s)' % type_.length + else: + return 'CHAR' + elif isinstance(type_, _BinaryType): + return 'BINARY' + elif isinstance(type_, MSNumeric): + return type_.get_col_spec().replace('NUMERIC', 'DECIMAL') + elif isinstance(type_, MSTimeStamp): + return 'DATETIME' + elif isinstance(type_, (MSDateTime, MSDate, MSTime)): + return type_.get_col_spec() else: - # so just skip the CAST altogether for now. - # TODO: put whatever MySQL does for CAST here. + return None + + def visit_cast(self, cast, **kwargs): + # No cast until 4, no decimals until 5. + type_ = self.process(cast.typeclause) + if type_ is None: return self.process(cast.clause) + return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) + + + def get_select_precolumns(self, select): + if isinstance(select._distinct, basestring): + return select._distinct.upper() + " " + elif select._distinct: + return "DISTINCT " + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + # 'JOIN ... ON ...' for inner joins isn't available until 4.0. + # Apparently < 3.23.17 requires theta joins for inner joins + # (but not outer). Not generating these currently, but + # support can be added, preferably after dialects are + # refactored to be version-sensitive. 
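visit_typeclause and visit_cast above boil down to "map the type to a MySQL CAST target, or skip the CAST entirely". A simplified sketch using plain type-name strings in place of the dialect's type objects; the mapping below is illustrative, not exhaustive:

    def mysql_cast(expr, type_name):
        # MySQL only accepts a handful of CAST targets; anything else is
        # rendered as the bare expression, as visit_cast does above.
        targets = {
            'integer': 'SIGNED INTEGER',
            'unsigned_integer': 'UNSIGNED INTEGER',
            'text': 'CHAR',
            'binary': 'BINARY',
            'numeric': 'DECIMAL',
            'datetime': 'DATETIME',
        }
        spec = targets.get(type_name)
        if spec is None:
            return expr
        return 'CAST(%s AS %s)' % (expr, spec)

    assert mysql_cast('col', 'integer') == 'CAST(col AS SIGNED INTEGER)'
    assert mysql_cast('col', 'boolean') == 'col'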
+ return ''.join( + (self.process(join.left, asfrom=True), + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), + self.process(join.right, asfrom=True), + " ON ", + self.process(join.onclause))) + def for_update_clause(self, select): if select.for_update == 'read': return ' LOCK IN SHARE MODE' @@ -1330,40 +1971,97 @@ class MySQLCompiler(ansisql.ANSICompiler): return super(MySQLCompiler, self).for_update_clause(select) def limit_clause(self, select): - text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - # straight from the MySQL docs, I kid you not - text += " \n LIMIT 18446744073709551615" - text += " OFFSET " + str(select._offset) + # MySQL supports: + # LIMIT + # LIMIT , + # and in server versions > 3.3: + # LIMIT OFFSET + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. + + limit, offset = select._limit, select._offset + + if (limit, offset) == (None, None): + return '' + elif offset is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + if limit is None: + limit = 18446744073709551615 + return ' \n LIMIT %s, %s' % (offset, limit) + else: + # No offset provided, so just use the limit + return ' \n LIMIT %s' % (limit,) + + def visit_update(self, update_stmt): + self.stack.append({'from':util.Set([update_stmt.table])}) + + self.isupdate = True + colparams = self._get_colparams(update_stmt) + + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) + + if update_stmt._whereclause: + text += " WHERE " + self.process(update_stmt._whereclause) + + limit = update_stmt.kwargs.get('mysql_limit', None) + if limit: + text += " LIMIT %s" % limit + + self.stack.pop(-1) + return text - -class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, first_pk=False): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() +# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. +# Starting with MySQL 4.1.2, these indexes are created automatically. +# In older versions, the indexes must be created explicitly or the +# creation of foreign key constraints fails." 
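The limit_clause logic above, including the large sentinel value the MySQL docs suggest when an offset is given without a limit, shown in isolation:

    def mysql_limit_clause(limit, offset):
        # Renders 'LIMIT <offset>, <limit>'; an offset without a limit gets the
        # maximum BIGINT value as an artificial limit, per the MySQL docs.
        if (limit, offset) == (None, None):
            return ''
        elif offset is not None:
            if limit is None:
                limit = 18446744073709551615
            return ' \n LIMIT %s, %s' % (offset, limit)
        else:
            return ' \n LIMIT %s' % (limit,)

    assert mysql_limit_clause(10, None) == ' \n LIMIT 10'
    assert mysql_limit_clause(10, 20) == ' \n LIMIT 20, 10'
    assert mysql_limit_clause(None, 20) == ' \n LIMIT 20, 18446744073709551615'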
+ +class MySQLSchemaGenerator(compiler.SchemaGenerator): + def get_column_specification(self, column, first_pk=False): + """Builds column DDL.""" + + colspec = [self.preparer.format_column(column), + column.type.dialect_impl(self.dialect, + _for_ddl=column).get_col_spec()] + default = self.get_column_default_string(column) if default is not None: - colspec += " DEFAULT " + default + colspec.append('DEFAULT ' + default) if not column.nullable: - colspec += " NOT NULL" - if column.primary_key: - if len(column.foreign_keys)==0 and first_pk and column.autoincrement and isinstance(column.type, sqltypes.Integer): - colspec += " AUTO_INCREMENT" - return colspec + colspec.append('NOT NULL') + + if column.primary_key and column.autoincrement: + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + isinstance(c.type, sqltypes.Integer) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('AUTO_INCREMENT') + except IndexError: + pass + + return ' '.join(colspec) def post_create_table(self, table): - args = "" + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] for k in table.kwargs: if k.startswith('mysql_'): - opt = k[6:] - args += " %s=%s" % (opt.upper(), table.kwargs[k]) - return args + opt = k[6:].upper() + joiner = '=' + if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', + 'CHARACTER SET', 'COLLATE'): + joiner = ' ' + + table_opts.append(joiner.join((opt, table.kwargs[k]))) + return ' '.join(table_opts) + -class MySQLSchemaDropper(ansisql.ANSISchemaDropper): +class MySQLSchemaDropper(compiler.SchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX %s ON %s" % (self.preparer.format_index(index), @@ -1376,18 +2074,632 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper): self.preparer.format_constraint(constraint))) self.execute() -class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): - def __init__(self, dialect): - super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote='`') - def _reserved_words(self): - return RESERVED_WORDS +class MySQLSchemaReflector(object): + """Parses SHOW CREATE TABLE output.""" + + def __init__(self, identifier_preparer): + """Construct a MySQLSchemaReflector. + + identifier_preparer + An ANSIIdentifierPreparer type, used to determine the identifier + quoting style in effect. + """ + + self.preparer = identifier_preparer + self._prep_regexes() + + def reflect(self, connection, table, show_create, charset, only=None): + """Parse MySQL SHOW CREATE TABLE and fill in a ''Table''. + + show_create + Unicode output of SHOW CREATE TABLE + + table + A ''Table'', to be loaded with Columns, Indexes, etc. + table.name will be set if not already + + charset + FIXME, some constructed values (like column defaults) + currently can't be Unicode. ''charset'' will convert them + into the connection character set. + + only + An optional sequence of column names. If provided, only + these columns will be reflected, and any keys or constraints + that include columns outside this set will also be omitted. + That means that if ``only`` includes only one column in a + 2 part primary key, the entire primary key will be omitted. 
+ """ + + keys, constraints = [], [] + + if only: + only = util.Set(only) + + for line in re.split(r'\r?\n', show_create): + if line.startswith(' ' + self.preparer.initial_quote): + self._add_column(table, line, charset, only) + # a regular table options line + elif line.startswith(') '): + self._set_options(table, line) + # an ANSI-mode table options line + elif line == ')': + pass + elif line.startswith('CREATE '): + self._set_name(table, line) + # Not present in real reflection, but may be if loading from a file. + elif not line: + pass + else: + type_, spec = self.parse_constraints(line) + if type_ is None: + util.warn("Unknown schema content: %r" % line) + elif type_ == 'key': + keys.append(spec) + elif type_ == 'constraint': + constraints.append(spec) + else: + pass + + self._set_keys(table, keys, only) + self._set_constraints(table, constraints, connection, only) + + def _set_name(self, table, line): + """Override a Table name with the reflected name. + + table + A ``Table`` + + line + The first line of SHOW CREATE TABLE output. + """ + + # Don't override by default. + if table.name is None: + table.name = self.parse_name(line) + + def _add_column(self, table, line, charset, only=None): + spec = self.parse_column(line) + if not spec: + util.warn("Unknown column definition %r" % line) + return + if not spec['full']: + util.warn("Incomplete reflection of column definition %r" % line) + + name, type_, args, notnull = \ + spec['name'], spec['coltype'], spec['arg'], spec['notnull'] + + if only and name.lower() not in only: + self.logger.info("Omitting reflected column %s.%s" % + (table.name, name)) + return + + # Convention says that TINYINT(1) columns == BOOLEAN + if type_ == 'tinyint' and args == '1': + type_ = 'boolean' + args = None + + try: + col_type = ischema_names[type_] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (type_, name)) + col_type = sqltypes.NullType + + # Column type positional arguments eg. varchar(32) + if args is None or args == '': + type_args = [] + elif args[0] == "'" and args[-1] == "'": + type_args = self._re_csv_str.findall(args) + else: + type_args = [int(v) for v in self._re_csv_int.findall(args)] + + # Column type keyword options + type_kw = {} + for kw in ('unsigned', 'zerofill'): + if spec.get(kw, False): + type_kw[kw] = True + for kw in ('charset', 'collate'): + if spec.get(kw, False): + type_kw[kw] = spec[kw] + + type_instance = col_type(*type_args, **type_kw) + + col_args, col_kw = [], {} + + # NOT NULL + if spec.get('notnull', False): + col_kw['nullable'] = False + + # AUTO_INCREMENT + if spec.get('autoincr', False): + col_kw['autoincrement'] = True + elif issubclass(col_type, sqltypes.Integer): + col_kw['autoincrement'] = False + + # DEFAULT + default = spec.get('default', None) + if default is not None and default != 'NULL': + # Defaults should be in the native charset for the moment + default = default.encode(charset) + if type_ == 'timestamp': + # can't be NULL for TIMESTAMPs + if (default[0], default[-1]) != ("'", "'"): + default = sql.text(default) + else: + default = default[1:-1] + col_args.append(schema.PassiveDefault(default)) + + table.append_column(schema.Column(name, type_instance, + *col_args, **col_kw)) + + def _set_keys(self, table, keys, only): + """Add ``Index`` and ``PrimaryKeyConstraint`` items to a ``Table``. + + Most of the information gets dropped here- more is reflected than + the schema objects can currently represent. 
+ + table + A ``Table`` + + keys + A sequence of key specifications produced by `constraints` + + only + Optional `set` of column names. If provided, keys covering + columns not in this set will be omitted. + """ + + for spec in keys: + flavor = spec['type'] + col_names = [s[0] for s in spec['columns']] + + if only and not util.Set(col_names).issubset(only): + if flavor is None: + flavor = 'index' + self.logger.info( + "Omitting %s KEY for (%s), key covers omitted columns." % + (flavor, ', '.join(col_names))) + continue + + constraint = False + if flavor == 'PRIMARY': + key = schema.PrimaryKeyConstraint() + constraint = True + elif flavor == 'UNIQUE': + key = schema.Index(spec['name'], unique=True) + elif flavor in (None, 'FULLTEXT', 'SPATIAL'): + key = schema.Index(spec['name']) + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY" % flavor) + key = schema.Index(spec['name']) + + for col in [table.c[name] for name in col_names]: + key.append_column(col) + + if constraint: + table.append_constraint(key) + + def _set_constraints(self, table, constraints, connection, only): + """Apply constraints to a ``Table``.""" + + for spec in constraints: + # only FOREIGN KEYs + ref_name = spec['table'][-1] + ref_schema = len(spec['table']) > 1 and spec['table'][-2] or None + + loc_names = spec['local'] + if only and not util.Set(loc_names).issubset(only): + self.logger.info( + "Omitting FOREIGN KEY for (%s), key covers omitted " + "columns." % (', '.join(loc_names))) + continue + + ref_key = schema._get_table_key(ref_name, ref_schema) + if ref_key in table.metadata.tables: + ref_table = table.metadata.tables[ref_key] + else: + ref_table = schema.Table( + ref_name, table.metadata, schema=ref_schema, + autoload=True, autoload_with=connection) + + ref_names = spec['foreign'] + if not util.Set(ref_names).issubset( + util.Set([c.name for c in ref_table.c])): + raise exceptions.InvalidRequestError( + "Foreign key columns (%s) are not present on " + "foreign table %s" % + (', '.join(ref_names), ref_table.fullname())) + ref_columns = [ref_table.c[name] for name in ref_names] + + con_kw = {} + for opt in ('name', 'onupdate', 'ondelete'): + if spec.get(opt, False): + con_kw[opt] = spec[opt] + + key = schema.ForeignKeyConstraint([], [], **con_kw) + table.append_constraint(key) + for pair in zip(loc_names, ref_columns): + key.append_element(*pair) + + def _set_options(self, table, line): + """Apply safe reflected table options to a ``Table``. + + table + A ``Table`` + + line + The final line of SHOW CREATE TABLE output. + """ + + options = self.parse_table_options(line) + for nope in ('auto_increment', 'data_directory', 'index_directory'): + options.pop(nope, None) + + for opt, val in options.items(): + table.kwargs['mysql_%s' % opt] = val + + def _prep_regexes(self): + """Pre-compile regular expressions.""" + + self._re_columns = [] + self._pr_options = [] + self._re_options_util = {} + + _final = self.preparer.final_quote + + quotes = dict(zip(('iq', 'fq', 'esc_fq'), + [re.escape(s) for s in + (self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final))])) + + self._pr_name = _pr_compile( + r'^CREATE TABLE +' + r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes, + self.preparer._unescape_identifier) + + # `col`,`col2`(32),`col3`(15) DESC + # + # Note: ASC and DESC aren't reflected, so we'll punt...
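The key-column lists described in the comment above ("`col`,`col2`(32)") unpack into (name, length) pairs. A standalone sketch with the quoting hard-coded to backquotes, whereas the reflector builds its pattern from the active identifier preparer:

    import re

    # Simplified: assumes backquote quoting rather than deriving it from the
    # identifier preparer as _prep_regexes does.
    _keyexprs = re.compile(r'(?:(?:`((?:``|[^`])+)`)(?:\((\d+)\))?(?=,|$))+')

    def parse_keyexprs(identifiers):
        # '`a`,`b`(32)' -> [('a', ''), ('b', '32')]
        return _keyexprs.findall(identifiers)

    assert parse_keyexprs('`a`,`b`(32),`c`') == [('a', ''), ('b', '32'), ('c', '')]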
+ self._re_keyexprs = _re_compile( + r'(?:' + r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)' + r'(?:\((\d+)\))?(?=\,|$))+' % quotes) + + # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' + self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27') + + # 123 or 123,456 + self._re_csv_int = _re_compile(r'\d+') + + + # `colname` [type opts] + # (NOT NULL | NULL) + # DEFAULT ('value' | CURRENT_TIMESTAMP...) + # COMMENT 'comment' + # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT) + # STORAGE (DISK|MEMORY) + self._re_column = _re_compile( + r' ' + r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'(?P<coltype>\w+)' + r'(?:\((?P<arg>(?:\d+|\d+,\d+|' + r'(?:\x27(?:\x27\x27|[^\x27])*\x27,?)+))\))?' + r'(?: +(?P<unsigned>UNSIGNED))?' + r'(?: +(?P<zerofill>ZEROFILL))?' + r'(?: +CHARACTER SET +(?P<charset>\w+))?' + r'(?: +COLLATE +(?P<collate>\w+))?' + r'(?: +(?P<notnull>NOT NULL))?' + r'(?: +DEFAULT +(?P<default>' + r'(?:NULL|\x27(?:\x27\x27|[^\x27])*\x27|\w+)' + r'(?:ON UPDATE \w+)?' + r'))?' + r'(?: +(?P<autoincr>AUTO_INCREMENT))?' + r'(?: +COMMENT +(?P<comment>(?:\x27\x27|[^\x27])+))?' + r'(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?' + r'(?: +STORAGE +(?P<storage>\w+))?' + r'(?: +(?P<extra>.*))?' + r',?$' + % quotes + ) + + # Fallback, try to parse as little as possible + self._re_column_loose = _re_compile( + r' ' + r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'(?P<coltype>\w+)' + r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?' + r'.*?(?P<notnull>NOT NULL)?' + % quotes + ) + + # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? + # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) + # KEY_BLOCK_SIZE size | WITH PARSER name + self._re_key = _re_compile( + r' ' + r'(?:(?P<type>\S+) )?KEY' + r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?' + r'(?: +USING +(?P<using>\S+))?' + r' +\((?P<columns>.+?)\)' + r'(?: +KEY_BLOCK_SIZE +(?P<keyblock>\S+))?' + r'(?: +WITH PARSER +(?P<parser>\S+))?' + r',?$' + % quotes + ) + + # CONSTRAINT `name` FOREIGN KEY (`local_col`) + # REFERENCES `remote` (`remote_col`) + # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE + # ON DELETE CASCADE ON UPDATE RESTRICT + # + # unique constraints come back as KEYs + kw = quotes.copy() + kw['on'] = 'RESTRICT|CASCADE|SET NULL|NOACTION' + self._re_constraint = _re_compile( + r' ' + r'CONSTRAINT +' + r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' + r'FOREIGN KEY +' + r'\((?P<local>[^\)]+?)\) REFERENCES +' + r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s) +' + r'\((?P<foreign>[^\)]+?)\)' + r'(?: +(?P<match>MATCH \w+))?' + r'(?: +ON DELETE (?P<ondelete>%(on)s))?' + r'(?: +ON UPDATE (?P<onupdate>%(on)s))?' + % kw + ) + + # PARTITION + # + # punt! + self._re_partition = _re_compile( + r' ' + r'(?:SUB)?PARTITION') + + # Table-level options (COLLATE, ENGINE, etc.) + for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT', + 'AVG_ROW_LENGTH', 'CHARACTER SET', + 'DEFAULT CHARSET', 'CHECKSUM', + 'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD', + 'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT', + 'KEY_BLOCK_SIZE'): + self._add_option_word(option) + + for option in (('COMMENT', 'DATA_DIRECTORY', 'INDEX_DIRECTORY', + 'PASSWORD', 'CONNECTION')): + self._add_option_string(option) + + self._add_option_regex('UNION', r'\([^\)]+\)') + self._add_option_regex('TABLESPACE', r'.*?
STORAGE DISK') + self._add_option_regex('RAID_TYPE', + r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') + self._re_options_util['='] = _re_compile(r'\s*=\s*$') + + def _add_option_string(self, directive): + regex = (r'(?P<directive>%s\s*(?:=\s*)?)' + r'(?:\x27.(?P<val>.*?)\x27(?!\x27)\x27)' % + re.escape(directive)) + self._pr_options.append( + _pr_compile(regex, lambda v: v.replace("''", "'"))) + + def _add_option_word(self, directive): + regex = (r'(?P<directive>%s\s*(?:=\s*)?)' + r'(?P<val>\w+)' % re.escape(directive)) + self._pr_options.append(_pr_compile(regex)) + + def _add_option_regex(self, directive, regex): + regex = (r'(?P<directive>%s\s*(?:=\s*)?)' + r'(?P<val>%s)' % (re.escape(directive), regex)) + self._pr_options.append(_pr_compile(regex)) + + + def parse_name(self, line): + """Extract the table name. + + line + The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if not m: + return None + return cleanup(m.group('name')) + + def parse_column(self, line): + """Extract column details. + + Falls back to a 'minimal support' variant if full parse fails. + + line + Any column-bearing line from SHOW CREATE TABLE + """ + + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec['full'] = True + return spec + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec['full'] = False + return spec + return None + + def parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + line + A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + spec['columns'] = self._parse_keyexprs(spec['columns']) + return 'key', spec + + # CONSTRAINT + m = self._re_constraint.match(line) + if m: + spec = m.groupdict() + spec['table'] = \ + self.preparer.unformat_identifiers(spec['table']) + spec['local'] = [c[0] + for c in self._parse_keyexprs(spec['local'])] + spec['foreign'] = [c[0] + for c in self._parse_keyexprs(spec['foreign'])] + return 'constraint', spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return 'partition', line + + # No match. + return (None, line) + + def parse_table_options(self, line): + """Build a dictionary of all reflected table-level options. + + line + The final line of SHOW CREATE TABLE output. + """ + + options = {} + + if not line or line == ')': + return options + + r_eq_trim = self._re_options_util['='] + + for regex, cleanup in self._pr_options: + m = regex.search(line) + if not m: + continue + directive, value = m.group('directive'), m.group('val') + directive = r_eq_trim.sub('', directive).lower() + if cleanup: + value = cleanup(value) + options[directive] = value + + return options + + def _describe_to_create(self, table, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. + + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only - keys are omitted. + + `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function.
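parse_table_options above walks a list of pre-built directive patterns. A cut-down sketch with three hard-coded "word" directives standing in for the generated ones:

    import re

    # Hypothetical, simplified stand-ins for the generated option patterns.
    _options = [re.compile(r'(?P<directive>%s\s*(?:=\s*)?)(?P<val>\w+)' % re.escape(d), re.I)
                for d in ('ENGINE', 'DEFAULT CHARSET', 'AUTO_INCREMENT')]
    _eq_trim = re.compile(r'\s*=\s*$')

    def parse_table_options(line):
        # Pull 'ENGINE=InnoDB'-style options off the closing line of
        # SHOW CREATE TABLE output into a dict keyed by lower-cased directive.
        options = {}
        for regex in _options:
            m = regex.search(line)
            if not m:
                continue
            directive = _eq_trim.sub('', m.group('directive')).lower()
            options[directive] = m.group('val')
        return options

    opts = parse_table_options(') ENGINE=InnoDB AUTO_INCREMENT=3 DEFAULT CHARSET=utf8')
    assert opts == {'engine': 'InnoDB', 'auto_increment': '3', 'default charset': 'utf8'}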
+ """ + + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = \ + [row[i] for i in (0, 1, 2, 4, 5)] + + line = [' '] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append('NOT NULL') + if default: + if 'auto_increment' in default: + pass + elif (col_type.startswith('timestamp') and + default.startswith('C')): + line.append('DEFAULT') + line.append(default) + elif default == 'NULL': + line.append('DEFAULT') + line.append(default) + else: + line.append('DEFAULT') + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) + + buffer.append(' '.join(line)) + + return ''.join([('CREATE TABLE %s (\n' % + self.preparer.quote_identifier(table.name)), + ',\n'.join(buffer), + '\n) ']) + + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" + + return self._re_keyexprs.findall(identifiers) + +MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector) + + +class _MySQLIdentifierPreparer(compiler.IdentifierPreparer): + """MySQL-specific schema identifier configuration.""" + + reserved_words = RESERVED_WORDS + + def __init__(self, dialect, **kw): + super(_MySQLIdentifierPreparer, self).__init__(dialect, **kw) + + def _quote_free_identifiers(self, *ids): + """Unilaterally identifier-quote any number of strings.""" + + return tuple([self.quote_identifier(i) for i in ids if i is not None]) + + +class MySQLIdentifierPreparer(_MySQLIdentifierPreparer): + """Traditional MySQL-specific schema identifier configuration.""" + + def __init__(self, dialect): + super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote="`") def _escape_identifier(self, value): return value.replace('`', '``') - def _fold_identifier_case(self, value): - #TODO: determine MySQL's case folding rules - return value + def _unescape_identifier(self, value): + return value.replace('``', '`') + + +class MySQLANSIIdentifierPreparer(_MySQLIdentifierPreparer): + """ANSI_QUOTES MySQL schema identifier configuration.""" + + pass + + +def _compat_fetchall(rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" + + return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] + +def _compat_fetchone(rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" + + return _MySQLPythonRowProxy(rp.fetchone(), charset) + +def _pr_compile(regex, cleanup=None): + """Prepare a 2-tuple of compiled regex and callable.""" + + return (_re_compile(regex), cleanup) + +def _re_compile(regex): + """Compile a string to regex, I and UNICODE.""" + + return re.compile(regex, re.I | re.UNICODE) dialect = MySQLDialect +dialect.statement_compiler = MySQLCompiler +dialect.schemagenerator = MySQLSchemaGenerator +dialect.schemadropper = MySQLSchemaDropper diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index d3aa2e268f..734ad58d10 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -1,17 +1,17 @@ # oracle.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, warnings, operator +import datetime, random, re -from sqlalchemy import util, sql, schema, ansisql, exceptions, logging +from sqlalchemy 
import util, sql, schema, exceptions, logging from sqlalchemy.engine import default, base -import sqlalchemy.types as sqltypes - -import datetime +from sqlalchemy.sql import compiler, visitors +from sqlalchemy.sql import operators as sql_operators, functions as sql_functions +from sqlalchemy import types as sqltypes class OracleNumeric(sqltypes.Numeric): @@ -32,25 +32,30 @@ class OracleSmallInteger(sqltypes.Smallinteger): class OracleDate(sqltypes.Date): def get_col_spec(self): return "DATE" - def convert_bind_param(self, value, dialect): - return value - def convert_result_value(self, value, dialect): - if not isinstance(value, datetime.datetime): - return value - else: - return value.date() + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + def process(value): + if not isinstance(value, datetime.datetime): + return value + else: + return value.date() + return process class OracleDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATE" - - def convert_result_value(self, value, dialect): - if value is None or isinstance(value,datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, - value.day,value.hour, value.minute, value.second) + + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value,datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year,value.month, + value.day,value.hour, value.minute, value.second) + return process # Note: # Oracle DATE == DATETIME @@ -65,39 +70,43 @@ class OracleTimestamp(sqltypes.TIMESTAMP): def get_dbapi_type(self, dialect): return dialect.TIMESTAMP - def convert_result_value(self, value, dialect): - if value is None or isinstance(value,datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, - value.day,value.hour, value.minute, value.second) - + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value,datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year,value.month, + value.day,value.hour, value.minute, value.second) + return process class OracleString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} -class OracleText(sqltypes.TEXT): +class OracleText(sqltypes.Text): def get_dbapi_type(self, dbapi): return dbapi.CLOB def get_col_spec(self): return "CLOB" - def convert_result_value(self, value, dialect): - if value is None: - return None - elif hasattr(value, 'read'): - # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str - return super(OracleText, self).convert_result_value(value.read(), dialect) - else: - return super(OracleText, self).convert_result_value(value, dialect) - + def result_processor(self, dialect): + super_process = super(OracleText, self).result_processor(dialect) + lob = dialect.dbapi.LOB + def process(value): + if isinstance(value, lob): + if super_process: + return super_process(value.read()) + else: + return value.read() + else: + if super_process: + return super_process(value) + else: + return value + return process -class OracleRaw(sqltypes.Binary): - def get_col_spec(self): - return "RAW(%(length)s)" % {'length' : self.length} class OracleChar(sqltypes.CHAR): def 
get_col_spec(self): @@ -110,33 +119,44 @@ class OracleBinary(sqltypes.Binary): def get_col_spec(self): return "BLOB" - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if value is None: - return None - else: - return value.read() + def result_processor(self, dialect): + lob = dialect.dbapi.LOB + def process(value): + if isinstance(value, lob): + return value.read() + else: + return value + return process + +class OracleRaw(OracleBinary): + def get_col_spec(self): + return "RAW(%(length)s)" % {'length' : self.length} class OracleBoolean(sqltypes.Boolean): def get_col_spec(self): return "SMALLINT" - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False - - def convert_bind_param(self, value, dialect): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: + def result_processor(self, dialect): + def process(value): + if value is None: + return None return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process colspecs = { sqltypes.Integer : OracleInteger, @@ -148,14 +168,14 @@ colspecs = { sqltypes.String : OracleString, sqltypes.Binary : OracleBinary, sqltypes.Boolean : OracleBoolean, - sqltypes.TEXT : OracleText, + sqltypes.Text : OracleText, sqltypes.TIMESTAMP : OracleTimestamp, sqltypes.CHAR: OracleChar, } ischema_names = { 'VARCHAR2' : OracleString, - 'DATE' : OracleDate, + 'DATE' : OracleDateTime, 'DATETIME' : OracleDateTime, 'NUMBER' : OracleNumeric, 'BLOB' : OracleBinary, @@ -181,22 +201,25 @@ class OracleExecutionContext(default.DefaultExecutionContext): super(OracleExecutionContext, self).pre_exec() if self.dialect.auto_setinputsizes: self.set_input_sizes() - if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list): - for key in self.compiled_parameters: - (bindparam, name, value) = self.compiled_parameters.get_parameter(key) + if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: + for key in self.compiled.binds: + bindparam = self.compiled.binds[key] + name = self.compiled.bind_names[bindparam] + value = self.compiled_parameters[0][name] if bindparam.isoutparam: dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if not hasattr(self, 'out_parameters'): self.out_parameters = {} self.out_parameters[name] = self.cursor.var(dbtype) - self.parameters[name] = self.out_parameters[name] + self.parameters[0][name] = self.out_parameters[name] def get_result_proxy(self): if hasattr(self, 'out_parameters'): - if self.compiled_parameters is not None: - for k in self.out_parameters: - type = self.compiled_parameters.get_type(k) - self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect) + if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: + for bind, name in self.compiled.bind_names.iteritems(): + if name in self.out_parameters: + type = bind.type + self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue()) else: for k in self.out_parameters: self.out_parameters[k] = self.out_parameters[k].getvalue() @@ 
-206,43 +229,54 @@ class OracleExecutionContext(default.DefaultExecutionContext): type_code = column[1] if type_code in self.dialect.ORACLE_BINARY_TYPES: return base.BufferedColumnResultProxy(self) - + return base.ResultProxy(self) -class OracleDialect(ansisql.ANSIDialect): - def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs) +class OracleDialect(default.DefaultDialect): + supports_alter = True + supports_unicode_statements = False + max_identifier_length = 30 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + preexecute_pk_sequences = True + supports_pk_autoincrement = False + default_paramstyle = 'named' + + def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) self.use_ansi = use_ansi self.threaded = threaded + self.allow_twophase = allow_twophase self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) self.auto_setinputsizes = auto_setinputsizes self.auto_convert_lobs = auto_convert_lobs - - if self.dbapi is not None: - self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)] - else: + if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__: + self.dbapi_type_map = {} self.ORACLE_BINARY_TYPES = [] - - def dbapi_type_map(self): - if self.dbapi is None or not self.auto_convert_lobs: - return {} else: - return { - self.dbapi.NUMBER: OracleInteger(), - self.dbapi.CLOB: OracleText(), - self.dbapi.BLOB: OracleBinary(), - self.dbapi.STRING: OracleString(), - self.dbapi.TIMESTAMP: OracleTimestamp(), - self.dbapi.BINARY: OracleRaw(), - datetime.datetime: OracleDate() + # only use this for LOB objects. using it for strings, dates + # etc. leads to a little too much magic, reflection doesn't know if it should + # expect encoded strings or unicodes, etc. 
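The result_processor hooks above return per-dialect closures; for the LOB-backed Oracle types the closure simply materializes the value via read(). A sketch with a stub class standing in for cx_Oracle's LOB type:

    class _StubLOB(object):
        # Illustrative stand-in for cx_Oracle.LOB.
        def __init__(self, data):
            self._data = data
        def read(self):
            return self._data

    def lob_result_processor(lob_cls):
        # Returns a closure that reads LOB values and passes everything
        # else (including None) through untouched.
        def process(value):
            if isinstance(value, lob_cls):
                return value.read()
            return value
        return process

    process = lob_result_processor(_StubLOB)
    assert process(_StubLOB('payload')) == 'payload'
    assert process(None) is None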
+ self.dbapi_type_map = { + self.dbapi.CLOB: OracleText(), + self.dbapi.BLOB: OracleBinary(), + self.dbapi.BINARY: OracleRaw(), } + self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] def dbapi(cls): import cx_Oracle return cx_Oracle dbapi = classmethod(dbapi) - + def create_connect_args(self, url): + dialect_opts = dict(url.query) + for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', + 'threaded', 'allow_twophase'): + if opt in dialect_opts: + util.coerce_kw_type(dialect_opts, opt, bool) + setattr(self, opt, dialect_opts[opt]) + if url.database: # if we have a database, then we have a remote host port = url.port @@ -250,131 +284,175 @@ class OracleDialect(ansisql.ANSIDialect): port = int(port) else: port = 1521 - dsn = self.dbapi.makedsn(url.host,port,url.database) + dsn = self.dbapi.makedsn(url.host, port, url.database) else: # we have a local tnsname dsn = url.host + opts = dict( user=url.username, password=url.password, - dsn = dsn, - threaded = self.threaded + dsn=dsn, + threaded=self.threaded, + twophase=self.allow_twophase, ) - opts.update(url.query) - util.coerce_kw_type(opts, 'use_ansi', bool) + if 'mode' in url.query: + opts['mode'] = url.query['mode'] + if isinstance(opts['mode'], basestring): + mode = opts['mode'].upper() + if mode == 'SYSDBA': + opts['mode'] = self.dbapi.SYSDBA + elif mode == 'SYSOPER': + opts['mode'] = self.dbapi.SYSOPER + else: + util.coerce_kw_type(opts, 'mode', int) + # Can't set 'handle' or 'pool' via URL query args, use connect_args + return ([], opts) + def is_disconnect(self, e): + if isinstance(e, self.dbapi.InterfaceError): + return "not connected" in str(e) + else: + return "ORA-03114" in str(e) or "ORA-03113" in str(e) + def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def supports_unicode_statements(self): - """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" - return False - - def max_identifier_length(self): - return 30 - def oid_column_name(self, column): if not isinstance(column.table, (sql.TableClause, sql.Select)): return None else: return "rowid" - def create_execution_context(self, *args, **kwargs): - return OracleExecutionContext(self, *args, **kwargs) + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). 
its format is unspecified.""" + + id = random.randint(0,2**128) + return (0x1234, "%032x" % 9, "%032x" % id) + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) - def compiler(self, statement, bindparams, **kwargs): - return OracleCompiler(self, statement, bindparams, **kwargs) + def do_prepare_twophase(self, connection, xid): + connection.connection.prepare() - def schemagenerator(self, *args, **kwargs): - return OracleSchemaGenerator(self, *args, **kwargs) + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_rollback(connection.connection) - def schemadropper(self, *args, **kwargs): - return OracleSchemaDropper(self, *args, **kwargs) + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_commit(connection.connection) - def defaultrunner(self, connection, **kwargs): - return OracleDefaultRunner(connection, **kwargs) + def do_recover_twophase(self, connection): + pass + + def create_execution_context(self, *args, **kwargs): + return OracleExecutionContext(self, *args, **kwargs) def has_table(self, connection, table_name, schema=None): - cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()}) + cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':self._denormalize_name(table_name)}) return bool( cursor.fetchone() is not None ) def has_sequence(self, connection, sequence_name): - cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name""", {'name':sequence_name.upper()}) + cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name""", {'name':self._denormalize_name(sequence_name)}) return bool( cursor.fetchone() is not None ) - def _locate_owner_row(self, owner, name, rows, raiseerr=False): - """return the row in the given list of rows which references the given table name and owner name.""" - if not rows: - if raiseerr: - raise exceptions.NoSuchTableError(name) - else: - return None + def _normalize_name(self, name): + if name is None: + return None + elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)): + return name.lower().decode(self.encoding) else: - if owner is not None: - for row in rows: - if owner.upper() in row[0]: - return row - else: - if raiseerr: - raise exceptions.AssertionError("Specified owner %s does not own table %s" % (owner, name)) - else: - return None - else: - if len(rows)==1: - return rows[0] - else: - if raiseerr: - raise exceptions.AssertionError("There are multiple tables with name '%s' visible to the schema, you must specifiy owner" % name) - else: - return None - - def _resolve_table_owner(self, connection, name, table, dblink=''): - """Locate the given table in the ``ALL_TAB_COLUMNS`` view, - including searching for equivalent synonyms and dblinks. 
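A simplified sketch of the upper/lower case round trip performed by _normalize_name above and _denormalize_name below; quoting rules and encoding are ignored here, and the helpers are illustrative, not the dialect methods:

    def normalize_name(name):
        # Oracle stores unquoted identifiers in upper case; present all-upper
        # names to the Python side in lower case.
        if name is None:
            return None
        if name.upper() == name:
            return name.lower()
        return name

    def denormalize_name(name):
        # The inverse: all-lower names go back to the server in upper case,
        # while explicitly mixed-case names pass through unchanged.
        if name is None:
            return None
        if name.lower() == name:
            return name.upper()
        return name

    assert normalize_name('MY_TABLE') == 'my_table'
    assert denormalize_name('my_table') == 'MY_TABLE'
    assert denormalize_name('MixedCase') == 'MixedCase'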
+ return name.decode(self.encoding) + + def _denormalize_name(self, name): + if name is None: + return None + elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): + return name.upper().encode(self.encoding) + else: + return name.encode(self.encoding) + + def get_default_schema_name(self, connection): + return connection.execute('SELECT USER FROM DUAL').scalar() + get_default_schema_name = base.connection_memoize( + ('dialect', 'default_schema_name'))(get_default_schema_name) + + def table_names(self, connection, schema): + # note that table_names() isnt loading DBLINKed or synonym'ed tables + if schema is None: + s = "select table_name from all_tables where tablespace_name NOT IN ('SYSTEM', 'SYSAUX')" + cursor = connection.execute(s) + else: + s = "select table_name from all_tables where tablespace_name NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner" + cursor = connection.execute(s,{'owner':self._denormalize_name(schema)}) + return [self._normalize_name(row[0]) for row in cursor] + + def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): + """search for a local synonym matching the given desired owner/name. + + if desired_owner is None, attempts to locate a distinct owner. + + returns the actual name, owner, dblink name, and synonym name if found. """ - c = connection.execute ("select distinct OWNER from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name" % {'dblink':dblink}, {'table_name':name}) - rows = c.fetchall() - try: - row = self._locate_owner_row(table.owner, name, rows, raiseerr=True) - return name, row['OWNER'], '' - except exceptions.SQLAlchemyError: - # locate synonyms - c = connection.execute ("""select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK - from ALL_SYNONYMS%(dblink)s - where SYNONYM_NAME = :synonym_name - and (DB_LINK IS NOT NULL - or ((TABLE_NAME, TABLE_OWNER) in - (select TABLE_NAME, OWNER from ALL_TAB_COLUMNS%(dblink)s)))""" % {'dblink':dblink}, - {'synonym_name':name}) - rows = c.fetchall() - row = self._locate_owner_row(table.owner, name, rows) - if row is None: - row = self._locate_owner_row("PUBLIC", name, rows) - - if row is not None: - owner, name, dblink = row['TABLE_OWNER'], row['TABLE_NAME'], row['DB_LINK'] - if dblink: - dblink = '@' + dblink - if not owner: - # re-resolve table owner using new dblink variable - t1, owner, t2 = self._resolve_table_owner(connection, name, table, dblink=dblink) - else: - dblink = '' - return name, owner, dblink - raise + sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME + from ALL_SYNONYMS WHERE """ + + clauses = [] + params = {} + if desired_synonym: + clauses.append("SYNONYM_NAME=:synonym_name") + params['synonym_name'] = desired_synonym + if desired_owner: + clauses.append("TABLE_OWNER=:desired_owner") + params['desired_owner'] = desired_owner + if desired_table: + clauses.append("TABLE_NAME=:tname") + params['tname'] = desired_table + + sql += " AND ".join(clauses) + + result = connection.execute(sql, **params) + if desired_owner: + row = result.fetchone() + if row: + return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME'] + else: + return None, None, None, None + else: + rows = result.fetchall() + if len(rows) > 1: + raise exceptions.AssertionError("There are multiple tables visible to the schema, you must specify owner") + elif len(rows) == 1: + row = rows[0] + return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME'] + else: + return None, None, None, None def 
reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer - if not preparer.should_quote(table): - name = table.name.upper() + + resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False) + + if resolve_synonyms: + actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name)) else: - name = table.name + actual_name, owner, dblink, synonym = None, None, None, None - # search for table, including across synonyms and dblinks. - # locate the actual name of the table, the real owner, and any dblink clause needed. - actual_name, owner, dblink = self._resolve_table_owner(connection, name, table) + if not actual_name: + actual_name = self._denormalize_name(table.name) + if not dblink: + dblink = '' + if not owner: + owner = self._denormalize_name(table.schema) or self.get_default_schema_name(connection) c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner}) @@ -385,11 +463,7 @@ class OracleDialect(ansisql.ANSIDialect): found_table = True #print "ROW:" , row - (colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) - - # if name comes back as all upper, assume its case folded - if (colname.upper() == colname): - colname = colname.lower() + (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) if include_columns and colname not in include_columns: continue @@ -413,7 +487,8 @@ class OracleDialect(ansisql.ANSIDialect): try: coltype = ischema_names[coltype] except KeyError: - warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, colname))) + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, colname)) coltype = sqltypes.NULLTYPE colargs = [] @@ -422,16 +497,16 @@ class OracleDialect(ansisql.ANSIDialect): table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) - if not len(table.columns): + if not table.columns: raise exceptions.AssertionError("Couldn't find any column information for table %s" % actual_name) c = connection.execute("""SELECT ac.constraint_name, ac.constraint_type, - LOWER(loc.column_name) AS local_column, - LOWER(rem.table_name) AS remote_table, - LOWER(rem.column_name) AS remote_column, - LOWER(rem.owner) AS remote_owner + loc.column_name AS local_column, + rem.table_name AS remote_table, + rem.column_name AS remote_column, + rem.owner AS remote_owner FROM all_constraints%(dblink)s ac, all_cons_columns%(dblink)s loc, all_cons_columns%(dblink)s rem @@ -452,7 +527,7 @@ class OracleDialect(ansisql.ANSIDialect): if row is None: break #print "ROW:" , row - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) if cons_type == 'P': table.primary_key.add(table.c[local_column]) elif cons_type == 'R': @@ -463,10 +538,25 @@ class OracleDialect(ansisql.ANSIDialect): fks[cons_name] = fk if remote_table is None: # ticket 363 - warnings.warn("Got 'None' querying 'table_name' from all_cons_columns%(dblink)s - does the user 
have proper rights to the table?" % {'dblink':dblink}) + util.warn( + ("Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?") % {'dblink':dblink}) continue - refspec = ".".join([remote_table, remote_column]) - schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, owner=remote_owner) + + if resolve_synonyms: + ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(remote_owner), desired_table=self._denormalize_name(remote_table)) + if ref_synonym: + remote_table = self._normalize_name(ref_synonym) + remote_owner = self._normalize_name(ref_remote_owner) + + if not table.schema and self._denormalize_name(remote_owner) == owner: + refspec = ".".join([remote_table, remote_column]) + t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) + else: + refspec = ".".join([x for x in [remote_owner, remote_table, remote_column] if x]) + t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, schema=remote_owner, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) + if local_column not in fk[0]: fk[0].append(local_column) if refspec not in fk[1]: @@ -475,14 +565,6 @@ class OracleDialect(ansisql.ANSIDialect): for name, value in fks.iteritems(): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) - def do_executemany(self, c, statement, parameters, context=None): - rowcount = 0 - for param in parameters: - c.execute(statement, param) - rowcount += c.rowcount - if context is not None: - context._rowcount = rowcount - OracleDialect.logger = logging.class_logger(OracleDialect) @@ -490,24 +572,33 @@ class _OuterJoinColumn(sql.ClauseElement): __visit_name__ = 'outer_join_column' def __init__(self, column): self.column = column - -class OracleCompiler(ansisql.ANSICompiler): + def _get_from_objects(self, **kwargs): + return [] + +class OracleCompiler(compiler.DefaultCompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False. """ - operators = ansisql.ANSICompiler.operators.copy() + operators = compiler.DefaultCompiler.operators.copy() operators.update( { - operator.mod : lambda x, y:"mod(%s, %s)" % (x, y) + sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y) + } + ) + + functions = compiler.DefaultCompiler.functions.copy() + functions.update ( + { + sql_functions.now : 'CURRENT_TIMESTAMP' } ) def __init__(self, *args, **kwargs): super(OracleCompiler, self).__init__(*args, **kwargs) self.__wheres = {} - + def default_from(self): """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. 
@@ -521,58 +612,43 @@ class OracleCompiler(ansisql.ANSICompiler): def visit_join(self, join, **kwargs): if self.dialect.use_ansi: - return ansisql.ANSICompiler.visit_join(self, join, **kwargs) - - (where, parentjoin) = self.__wheres.get(join, (None, None)) - - class VisitOn(sql.ClauseVisitor): - def visit_binary(s, binary): - if binary.operator == operator.eq: - if binary.left.table is join.right: - binary.left = _OuterJoinColumn(binary.left) - elif binary.right.table is join.right: - binary.right = _OuterJoinColumn(binary.right) - - if where is not None: - self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin) + return compiler.DefaultCompiler.visit_join(self, join, **kwargs) else: - self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join) - - return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) - def get_whereclause(self, f): - if f in self.__wheres: - return self.__wheres[f][0] - else: - return None - + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + def visit_binary(binary): + if binary.operator == sql_operators.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + clauses.append(visitors.traverse(join.onclause, visit_binary=visit_binary, clone=True)) + else: + clauses.append(join.onclause) + + for f in froms: + visitors.traverse(f, visit_join=visit_join) + return sql.and_(*clauses) + def visit_outer_join_column(self, vc): return self.process(vc.column) + "(+)" - - def uses_sequences_for_inserts(self): - return True + + def visit_sequence(self, seq): + return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" def visit_alias(self, alias, asfrom=False, **kwargs): """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - + if asfrom: - return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + alias.name + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name)) else: return self.process(alias.original, **kwargs) - def visit_insert(self, insert): - """``INSERT`` s are required to have the primary keys be explicitly present. - - Mapper will by default not put them in the insert statement - to comply with autoincrement fields that require they not be - present. so, put them all in for all primary key columns. - """ - - for c in insert.table.primary_key: - if not self.parameters.has_key(c.key): - self.parameters[c.key] = None - return ansisql.ANSICompiler.visit_insert(self, insert) - def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" pass @@ -582,28 +658,43 @@ class OracleCompiler(ansisql.ANSICompiler): so tries to wrap it in a subquery with ``row_number()`` criterion. """ - if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None): - # to use ROW_NUMBER(), an ORDER BY is required. 
- orderby = self.process(select._order_by_clause) - if not orderby: - orderby = select.oid_column - self.traverse(orderby) - orderby = self.process(orderby) + if not getattr(select, '_oracle_visit', None): + if not self.dialect.use_ansi: + if self.stack and 'from' in self.stack[-1]: + existingfroms = self.stack[-1]['from'] + else: + existingfroms = None + + froms = select._get_display_froms(existingfroms) + whereclause = self._get_nonansi_join_whereclause(froms) + if whereclause: + select = select.where(whereclause) + select._oracle_visit = True - oldselect = select - select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None) - select._oracle_visit = True + if select._limit is not None or select._offset is not None: + # to use ROW_NUMBER(), an ORDER BY is required. + orderby = self.process(select._order_by_clause) + if not orderby: + orderby = list(select.oid_column.proxies)[0] + orderby = self.process(orderby) + + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None) + select._oracle_visit = True - limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select._offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select._offset) - if select._limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset)) - else: - limitselect.append_whereclause("ora_rn<=%d" % select._limit) - return self.process(limitselect) - else: - return ansisql.ANSICompiler.visit_select(self, select, **kwargs) + limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + if select._offset is not None: + limitselect.append_whereclause("ora_rn>%d" % select._offset) + if select._limit is not None: + limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset)) + else: + limitselect.append_whereclause("ora_rn<=%d" % select._limit) + select = limitselect + + kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) + return compiler.DefaultCompiler.visit_select(self, select, **kwargs) def limit_clause(self, select): return "" @@ -615,10 +706,10 @@ class OracleCompiler(ansisql.ANSICompiler): return super(OracleCompiler, self).for_update_clause(select) -class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): +class OracleSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -632,18 +723,25 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class OracleSchemaDropper(ansisql.ANSISchemaDropper): +class OracleSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): - self.append("DROP SEQUENCE %s" % sequence.name) + self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class OracleDefaultRunner(ansisql.ANSIDefaultRunner): - def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection) - return 
self.connection.execute(c).scalar() - +class OracleDefaultRunner(base.DefaultRunner): def visit_sequence(self, seq): - return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar() + return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {}) + +class OracleIdentifierPreparer(compiler.IdentifierPreparer): + def format_savepoint(self, savepoint): + name = re.sub(r'^_+', '', savepoint.ident) + return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + dialect = OracleDialect +dialect.statement_compiler = OracleCompiler +dialect.schemagenerator = OracleSchemaGenerator +dialect.schemadropper = OracleSchemaDropper +dialect.preparer = OracleIdentifierPreparer +dialect.defaultrunner = OracleDefaultRunner diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index b192c47788..605ce7272b 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -1,27 +1,41 @@ # postgres.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re, random, warnings, operator +"""Support for the PostgreSQL database. -from sqlalchemy import sql, schema, ansisql, exceptions +PostgreSQL supports partial indexes. To create them pass a posgres_where +option to the Index constructor:: + + Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10) + +PostgreSQL 8.2+ supports returning a result set from inserts and updates. +To use this pass the column/expression list to the postgres_returning +parameter when creating the queries:: + + raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), + postgres_returning=[empl.c.id, empl.c.salary]).execute().fetchall() +""" + +import random, re, string + +from sqlalchemy import sql, schema, exceptions, util from sqlalchemy.engine import base, default -import sqlalchemy.types as sqltypes -from sqlalchemy.databases import information_schema as ischema -from decimal import Decimal +from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes -try: - import mx.DateTime.DateTime as mxDateTime -except: - mxDateTime = None - class PGInet(sqltypes.TypeEngine): def get_col_spec(self): return "INET" +class PGMacAddr(sqltypes.TypeEngine): + def get_col_spec(self): + return "MACADDR" + class PGNumeric(sqltypes.Numeric): def get_col_spec(self): if not self.precision: @@ -29,15 +43,20 @@ class PGNumeric(sqltypes.Numeric): else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} - def convert_bind_param(self, value, dialect): - return value + def bind_processor(self, dialect): + return None - def convert_result_value(self, value, dialect): - if not self.asdecimal and isinstance(value, Decimal): - return float(value) + def result_processor(self, dialect): + if self.asdecimal: + return None else: - return value - + def process(value): + if isinstance(value, util.decimal_type): + return float(value) + else: + return value + return process + class PGFloat(sqltypes.Float): def get_col_spec(self): if not self.precision: @@ -74,7 +93,7 @@ class PGInterval(sqltypes.TypeEngine): def get_col_spec(self): return "INTERVAL" -class 
PGText(sqltypes.TEXT): +class PGText(sqltypes.Text): def get_col_spec(self): return "TEXT" @@ -94,36 +113,64 @@ class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" -class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): - def __init__(self, item_type): +class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): + def __init__(self, item_type, mutable=True): if isinstance(item_type, type): item_type = item_type() self.item_type = item_type - - def dialect_impl(self, dialect): + self.mutable = mutable + + def copy_value(self, value): + if value is None: + return None + elif self.mutable: + return list(value) + else: + return value + + def compare_values(self, x, y): + return x == y + + def is_mutable(self): + return self.mutable + + def dialect_impl(self, dialect, **kwargs): impl = self.__class__.__new__(self.__class__) impl.__dict__.update(self.__dict__) impl.item_type = self.item_type.dialect_impl(dialect) return impl - def convert_bind_param(self, value, dialect): - if value is None: - return value - def convert_item(item): - if isinstance(item, (list,tuple)): - return [convert_item(child) for child in item] - else: - return self.item_type.convert_bind_param(item, dialect) - return [convert_item(item) for item in value] - def convert_result_value(self, value, dialect): - if value is None: - return value - def convert_item(item): - if isinstance(item, list): - return [convert_item(child) for child in item] - else: - return self.item_type.convert_result_value(item, dialect) - # Could specialcase when item_type.convert_result_value is the default identity func - return [convert_item(item) for item in value] + + def bind_processor(self, dialect): + item_proc = self.item_type.bind_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list,tuple)): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process + + def result_processor(self, dialect): + item_proc = self.item_type.result_processor(dialect) + def process(value): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + if item_proc: + return item_proc(item) + else: + return item + return [convert_item(item) for item in value] + return process def get_col_spec(self): return self.item_type.get_col_spec() + '[]' @@ -138,7 +185,7 @@ colspecs = { sqltypes.String : PGString, sqltypes.Binary : PGBinary, sqltypes.Boolean : PGBoolean, - sqltypes.TEXT : PGText, + sqltypes.Text : PGText, sqltypes.CHAR: PGChar, } @@ -153,6 +200,7 @@ ischema_names = { 'float' : PGFloat, 'real' : PGFloat, 'inet': PGInet, + 'macaddr': PGMacAddr, 'double precision' : PGFloat, 'timestamp' : PGDateTime, 'timestamp with time zone' : PGDateTime, @@ -176,26 +224,70 @@ def descriptor(): ('host',"Hostname", None), ]} +SERVER_SIDE_CURSOR_RE = re.compile( + r'\s*SELECT', + re.I | re.UNICODE) + +SELECT_RE = re.compile( + r'\s*(?:SELECT|FETCH|(UPDATE|INSERT))', + re.I | re.UNICODE) + +RETURNING_RE = re.compile( + 'RETURNING', + re.I | re.UNICODE) + +# This finds if the RETURNING is not inside a quoted/commented values. Handles string literals, +# quoted identifiers, dollar quotes, SQL comments and C style multiline comments. 
This does not +# handle correctly nested C style quotes, lets hope no one does the following: +# UPDATE tbl SET x=y /* foo /* bar */ RETURNING */ +RETURNING_QUOTED_RE = re.compile( + """\s*(?:UPDATE|INSERT)\s + (?: # handle quoted and commented tokens separately + [^'"$/-] # non quote/comment character + | -(?!-) # a dash that does not begin a comment + | /(?!\*) # a slash that does not begin a comment + | "(?:[^"]|"")*" # quoted literal + | '(?:[^']|'')*' # quoted string + | \$(?P[^$]*)\$.*?\$(?P=dquote)\$ # dollar quotes + | --[^\\n]*(?=\\n) # SQL comment, leave out line ending as that counts as whitespace + # for the returning token + | /\*([^*]|\*(?!/))*\*/ # C style comment, doesn't handle nesting + )* + \sRETURNING\s""", re.I | re.UNICODE | re.VERBOSE) + class PGExecutionContext(default.DefaultExecutionContext): + def returns_rows_text(self, statement): + m = SELECT_RE.match(statement) + return m and (not m.group(1) or (RETURNING_RE.search(statement) + and RETURNING_QUOTED_RE.match(statement))) + + def returns_rows_compiled(self, compiled): + return isinstance(compiled.statement, expression.Selectable) or \ + ( + (compiled.isupdate or compiled.isinsert) and "postgres_returning" in compiled.statement.kwargs + ) - def _is_server_side(self): - return self.dialect.server_side_cursors and self.is_select() and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I) - def create_cursor(self): - if self._is_server_side(): + self.__is_server_side = \ + self.dialect.server_side_cursors and \ + ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)) \ + or \ + (not self.compiled and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))) + + if self.__is_server_side: # use server-side cursors: # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c" + hex(random.randint(0, 65535))[2:] - return self.connection.connection.cursor(ident) + ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) + return self._connection.connection.cursor(ident) else: - return self.connection.connection.cursor() - + return self._connection.connection.cursor() + def get_result_proxy(self): - if self._is_server_side(): + if self.__is_server_side: return base.BufferedRowResultProxy(self) else: return base.ResultProxy(self) - + def post_exec(self): if self.compiled.isinsert and self.last_inserted_ids is None: if not self.dialect.use_oids: @@ -208,51 +300,45 @@ class PGExecutionContext(default.DefaultExecutionContext): row = self.connection.execute(s).fetchone() self._last_inserted_ids = [v for v in row] super(PGExecutionContext, self).post_exec() - -class PGDialect(ansisql.ANSIDialect): - def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) + +class PGDialect(default.DefaultDialect): + supports_alter = True + supports_unicode_statements = False + max_identifier_length = 63 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + preexecute_pk_sequences = True + supports_pk_autoincrement = False + default_paramstyle = 'pyformat' + + def __init__(self, use_oids=False, server_side_cursors=False, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors - self.use_information_schema = use_information_schema - self.paramstyle = 'pyformat' - + def dbapi(cls): import psycopg2 as psycopg return psycopg dbapi = 
classmethod(dbapi) - + def create_connect_args(self, url): - opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) - if opts.has_key('port'): + opts = url.translate_connect_args(username='user') + if 'port' in opts: opts['port'] = int(opts['port']) opts.update(url.query) return ([], opts) - def create_execution_context(self, *args, **kwargs): return PGExecutionContext(self, *args, **kwargs) - def max_identifier_length(self): - return 63 - def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def compiler(self, statement, bindparams, **kwargs): - return PGCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return PGSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return PGSchemaDropper(self, *args, **kwargs) - def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) def do_prepare_twophase(self, connection, xid): - connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)])) def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): if is_prepared: @@ -260,7 +346,9 @@ class PGDialect(ansisql.ANSIDialect): #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions # Must find out a way how to make the dbapi not open a transaction. connection.execute(sql.text("ROLLBACK")) - connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)])) + connection.execute(sql.text("BEGIN")) + self.do_rollback(connection.connection) else: self.do_rollback(connection.connection) @@ -268,7 +356,9 @@ class PGDialect(ansisql.ANSIDialect): if is_prepared: if recover: connection.execute(sql.text("ROLLBACK")) - connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)])) + connection.execute(sql.text("BEGIN")) + self.do_rollback(connection.connection) else: self.do_commit(connection.connection) @@ -276,16 +366,10 @@ class PGDialect(ansisql.ANSIDialect): resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] - def defaultrunner(self, context, **kwargs): - return PGDefaultRunner(context, **kwargs) - - def preparer(self): - return PGIdentifierPreparer(self) - def get_default_schema_name(self, connection): - if not hasattr(self, '_default_schema_name'): - self._default_schema_name = connection.scalar("select current_schema()", None) - return self._default_schema_name + return connection.scalar("select current_schema()", None) + get_default_schema_name = base.connection_memoize( + ('dialect', 'default_schema_name'))(get_default_schema_name) def last_inserted_ids(self): if self.context.last_inserted_ids is None: @@ -299,19 +383,6 @@ class PGDialect(ansisql.ANSIDialect): else: return None - def do_executemany(self, c, statement, parameters, context=None): - """We need accurate rowcounts for updates, inserts and deletes. - - ``psycopg2`` is not nice enough to produce this correctly for - an executemany, so we do our own executemany here. 
- """ - rowcount = 0 - for param in parameters: - c.execute(statement, param) - rowcount += c.rowcount - if context is not None: - context._rowcount = rowcount - def has_table(self, connection, table_name, schema=None): # seems like case gets folded in pg_class... if schema is None: @@ -321,186 +392,213 @@ class PGDialect(ansisql.ANSIDialect): return bool( not not cursor.rowcount ) def has_sequence(self, connection, sequence_name): - cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name}) + cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)}) return bool(not not cursor.rowcount) def is_disconnect(self, e): if isinstance(e, self.dbapi.OperationalError): return 'closed the connection' in str(e) or 'connection not open' in str(e) elif isinstance(e, self.dbapi.InterfaceError): - return 'connection already closed' in str(e) + return 'connection already closed' in str(e) or 'cursor already closed' in str(e) elif isinstance(e, self.dbapi.ProgrammingError): # yes, it really says "losed", not "closed" return "losed the connection unexpectedly" in str(e) else: return False + def table_names(self, connection, schema): + s = """ + SELECT relname + FROM pg_class c + WHERE relkind = 'r' + AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) + """ % locals() + return [row[0].decode(self.encoding) for row in connection.execute(s)] + + def server_version_info(self, connection): + v = connection.execute("select version()").scalar() + m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v) + if not m: + raise exceptions.AssertionError("Could not determine version from string '%s'" % v) + return tuple([int(x) for x in m.group(1, 2, 3)]) + def reflecttable(self, connection, table, include_columns): - if self.use_information_schema: - ischema.reflecttable(connection, table, include_columns, ischema_names) + preparer = self.identifier_preparer + if table.schema is not None: + schema_where_clause = "n.nspname = :schema" + schemaname = table.schema + if isinstance(schemaname, str): + schemaname = schemaname.decode(self.encoding) else: - preparer = self.identifier_preparer - if table.schema is not None: - schema_where_clause = "n.nspname = :schema" - else: - schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - - ## information schema in pg suffers from too many permissions' restrictions - ## let us find out at the pg way what is needed... 
- - SQL_COLS = """ - SELECT a.attname, - pg_catalog.format_type(a.atttypid, a.atttypmod), - (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) - AS DEFAULT, - a.attnotnull, a.attnum, a.attrelid as table_oid - FROM pg_catalog.pg_attribute a - WHERE a.attrelid = ( - SELECT c.oid - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE (%s) - AND c.relname = :table_name AND c.relkind in ('r','v') - ) AND a.attnum > 0 AND NOT a.attisdropped - ORDER BY a.attnum - """ % schema_where_clause - - s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) - c = connection.execute(s, table_name=table.name, - schema=table.schema) - rows = c.fetchall() - - if not rows: - raise exceptions.NoSuchTableError(table.name) - - domains = self._load_domains(connection) - - for name, format_type, default, notnull, attnum, table_oid in rows: - if include_columns and name not in include_columns: - continue - - ## strip (30) from character varying(30) - attype = re.search('([^\([]+)', format_type).group(1) - nullable = not notnull - is_array = format_type.endswith('[]') - - try: - charlen = re.search('\(([\d,]+)\)', format_type).group(1) - except: - charlen = False - - numericprec = False - numericscale = False - if attype == 'numeric': - if charlen is False: - numericprec, numericscale = (None, None) - else: - numericprec, numericscale = charlen.split(',') - charlen = False - if attype == 'double precision': - numericprec, numericscale = (53, False) - charlen = False - if attype == 'integer': - numericprec, numericscale = (32, 0) - charlen = False - - args = [] - for a in (charlen, numericprec, numericscale): - if a is None: - args.append(None) - elif a is not False: - args.append(int(a)) - - kwargs = {} - if attype == 'timestamp with time zone': - kwargs['timezone'] = True - elif attype == 'timestamp without time zone': - kwargs['timezone'] = False - - if attype in ischema_names: - coltype = ischema_names[attype] - else: - if attype in domains: - domain = domains[attype] - if domain['attype'] in ischema_names: - # A table can't override whether the domain is nullable. - nullable = domain['nullable'] - - if domain['default'] and not default: - # It can, however, override the default value, but can't set it to null. 
- default = domain['default'] - coltype = ischema_names[domain['attype']] - else: - coltype=None + schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" + schemaname = None + + SQL_COLS = """ + SELECT a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) + AS DEFAULT, + a.attnotnull, a.attnum, a.attrelid as table_oid + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = ( + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE (%s) + AND c.relname = :table_name AND c.relkind in ('r','v') + ) AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + """ % schema_where_clause + + s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}) + tablename = table.name + if isinstance(tablename, str): + tablename = tablename.decode(self.encoding) + c = connection.execute(s, table_name=tablename, schema=schemaname) + rows = c.fetchall() + + if not rows: + raise exceptions.NoSuchTableError(table.name) + + domains = self._load_domains(connection) + + for name, format_type, default, notnull, attnum, table_oid in rows: + if include_columns and name not in include_columns: + continue - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = PGArray(coltype) + ## strip (30) from character varying(30) + attype = re.search('([^\([]+)', format_type).group(1) + nullable = not notnull + is_array = format_type.endswith('[]') + + try: + charlen = re.search('\(([\d,]+)\)', format_type).group(1) + except: + charlen = False + + numericprec = False + numericscale = False + if attype == 'numeric': + if charlen is False: + numericprec, numericscale = (None, None) else: - warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name))) - coltype = sqltypes.NULLTYPE - - colargs= [] - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - # the default is related to a Sequence - sch = table.schema - if '.' not in match.group(2) and sch is not None: - default = match.group(1) + sch + '.' 
+ match.group(2) + match.group(3) - colargs.append(schema.PassiveDefault(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) - - - # Primary keys - PK_SQL = """ - SELECT attname FROM pg_attribute - WHERE attrelid = ( - SELECT indexrelid FROM pg_index i - WHERE i.indrelid = :table - AND i.indisprimary = 't') - ORDER BY attnum - """ - t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) - for row in c.fetchall(): - pk = row[0] - table.primary_key.add(table.c[pk]) - - # Foreign keys - FK_SQL = """ - SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef - FROM pg_catalog.pg_constraint r - WHERE r.conrelid = :table AND r.contype = 'f' - ORDER BY 1 - """ - - t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) - for conname, condef in c.fetchall(): - m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() - (constrained_columns, referred_schema, referred_table, referred_columns) = m - constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] - if referred_schema: - referred_schema = preparer._unquote_identifier(referred_schema) - referred_table = preparer._unquote_identifier(referred_table) - referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] - - refspec = [] - if referred_schema is not None: - schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, - autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_schema, referred_table, column])) + numericprec, numericscale = charlen.split(',') + charlen = False + if attype == 'double precision': + numericprec, numericscale = (53, False) + charlen = False + if attype == 'integer': + numericprec, numericscale = (32, 0) + charlen = False + + args = [] + for a in (charlen, numericprec, numericscale): + if a is None: + args.append(None) + elif a is not False: + args.append(int(a)) + + kwargs = {} + if attype == 'timestamp with time zone': + kwargs['timezone'] = True + elif attype == 'timestamp without time zone': + kwargs['timezone'] = False + + if attype in ischema_names: + coltype = ischema_names[attype] + else: + if attype in domains: + domain = domains[attype] + if domain['attype'] in ischema_names: + # A table can't override whether the domain is nullable. + nullable = domain['nullable'] + + if domain['default'] and not default: + # It can, however, override the default value, but can't set it to null. + default = domain['default'] + coltype = ischema_names[domain['attype']] else: - schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_table, column])) + coltype=None + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = PGArray(coltype) + else: + util.warn("Did not recognize type '%s' of column '%s'" % + (attype, name)) + coltype = sqltypes.NULLTYPE + + colargs= [] + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + # the default is related to a Sequence + sch = table.schema + if '.' not in match.group(2) and sch is not None: + # unconditionally quote the schema name. 
this could + # later be enhanced to obey quoting rules / "quote schema" + default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) + colargs.append(schema.PassiveDefault(sql.text(default))) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) + + + # Primary keys + PK_SQL = """ + SELECT attname FROM pg_attribute + WHERE attrelid = ( + SELECT indexrelid FROM pg_index i + WHERE i.indrelid = :table + AND i.indisprimary = 't') + ORDER BY attnum + """ + t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) + c = connection.execute(t, table=table_oid) + for row in c.fetchall(): + pk = row[0] + col = table.c[pk] + table.primary_key.add(col) + if col.default is None: + col.autoincrement=False + + # Foreign keys + FK_SQL = """ + SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = :table AND r.contype = 'f' + ORDER BY 1 + """ + + t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode}) + c = connection.execute(t, table=table_oid) + for conname, condef in c.fetchall(): + m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() + (constrained_columns, referred_schema, referred_table, referred_columns) = m + constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] + if referred_schema: + referred_schema = preparer._unquote_identifier(referred_schema) + elif table.schema is not None and table.schema == self.get_default_schema_name(connection): + # no schema (i.e. its the default schema), and the table we're + # reflecting has the default schema explicit, then use that. + # i.e. try to use the user's conventions + referred_schema = table.schema + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] + + refspec = [] + if referred_schema is not None: + schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, + autoload_with=connection) + for column in referred_columns: + refspec.append(".".join([referred_schema, referred_table, column])) + else: + schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + + table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname)) - table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname)) - def _load_domains(self, connection): ## Load data types for domains: SQL_DOMAINS = """ @@ -524,7 +622,7 @@ class PGDialect(ansisql.ANSIDialect): ## strip (30) from character varying(30) attype = re.search('([^\(]+)', domain['attype']).group(1) if domain['visible']: - # 'visible' just means whether or not the domain is in a + # 'visible' just means whether or not the domain is in a # schema that's on the search path -- or not overriden by # a schema with higher presedence. If it's not visible, # it will be prefixed with the schema-name when it's used. 
@@ -535,19 +633,31 @@ class PGDialect(ansisql.ANSIDialect): domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']} return domains - - - -class PGCompiler(ansisql.ANSICompiler): - operators = ansisql.ANSICompiler.operators.copy() + + + +class PGCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() operators.update( { - operator.mod : '%%' + sql_operators.mod : '%%', + sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), + sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), } ) - def uses_sequences_for_inserts(self): - return True + functions = compiler.DefaultCompiler.functions.copy() + functions.update ( + { + 'TIMESTAMP':lambda x:'TIMESTAMP %s' % x, + } + ) + + def visit_sequence(self, seq): + if seq.optional: + return None + else: + return "nextval('%s')" % self.preparer.format_sequence(seq) def limit_clause(self, select): text = "" @@ -561,15 +671,14 @@ class PGCompiler(ansisql.ANSICompiler): def get_select_precolumns(self, select): if select._distinct: - if type(select._distinct) == bool: + if isinstance(select._distinct, bool): return "DISTINCT " - if type(select._distinct) == list: - dist_set = "DISTINCT ON (" - for col in select._distinct: - dist_set += self.strings[col] + ", " - dist_set = dist_set[:-2] + ") " - return dist_set - return "DISTINCT ON (" + str(select._distinct) + ") " + elif isinstance(select._distinct, (list, tuple)): + return "DISTINCT ON (" + ', '.join( + [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] + )+ ") " + else: + return "DISTINCT ON (" + unicode(select._distinct) + ") " else: return "" @@ -579,7 +688,34 @@ class PGCompiler(ansisql.ANSICompiler): else: return super(PGCompiler, self).for_update_clause(select) -class PGSchemaGenerator(ansisql.ANSISchemaGenerator): + def _append_returning(self, text, stmt): + returning_cols = stmt.kwargs['postgres_returning'] + def flatten_columnlist(collist): + for c in collist: + if isinstance(c, expression.Selectable): + for co in c.columns: + yield co + else: + yield c + columns = [self.process(c) for c in flatten_columnlist(returning_cols)] + text += ' RETURNING ' + string.join(columns, ', ') + return text + + def visit_update(self, update_stmt): + text = super(PGCompiler, self).visit_update(update_stmt) + if 'postgres_returning' in update_stmt.kwargs: + return self._append_returning(text, update_stmt) + else: + return text + + def visit_insert(self, insert_stmt): + text = super(PGCompiler, self).visit_insert(insert_stmt) + if 'postgres_returning' in insert_stmt.kwargs: + return self._append_returning(text, insert_stmt) + else: + return text + +class PGSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): @@ -588,7 +724,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): else: colspec += " SERIAL" else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() default = 
self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -602,18 +738,41 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class PGSchemaDropper(ansisql.ANSISchemaDropper): + def visit_index(self, index): + preparer = self.preparer + self.append("CREATE ") + if index.unique: + self.append("UNIQUE ") + self.append("INDEX %s ON %s (%s)" \ + % (preparer.format_index(index), + preparer.format_table(index.table), + string.join([preparer.format_column(c) for c in index.columns], ', '))) + whereclause = index.kwargs.get('postgres_where', None) + if whereclause is not None: + compiler = self._compile(whereclause, None) + # this might belong to the compiler class + inlined_clause = str(compiler) % dict( + [(key,bind.value) for key,bind in compiler.binds.iteritems()]) + self.append(" WHERE " + inlined_clause) + self.execute() + +class PGSchemaDropper(compiler.SchemaDropper): def visit_sequence(self, sequence): if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): - self.append("DROP SEQUENCE %s" % sequence.name) + self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() -class PGDefaultRunner(ansisql.ANSIDefaultRunner): +class PGDefaultRunner(base.DefaultRunner): + def __init__(self, context): + base.DefaultRunner.__init__(self, context) + # craete cursor which won't conflict with a server-side cursor + self.cursor = context._connection.connection.cursor() + def get_column_default(self, column, isinsert=True): if column.primary_key: - # passive defaults on primary keys have to be overridden + # pre-execute passive defaults on primary keys if isinstance(column.default, schema.PassiveDefault): - return self.connection.execute("select %s" % column.default.arg).scalar() + return self.execute_string("select %s" % column.default.arg) elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting @@ -622,23 +781,25 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - return self.connection.execute(exc).scalar() + return self.execute_string(exc.encode(self.dialect.encoding)) - return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) + return super(PGDefaultRunner, self).get_column_default(column) def visit_sequence(self, seq): if not seq.optional: - return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar() + return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq))) else: return None -class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): - def _fold_identifier_case(self, value): - return value.lower() - +class PGIdentifierPreparer(compiler.IdentifierPreparer): def _unquote_identifier(self, value): if value[0] == self.initial_quote: value = value[1:-1].replace('""','"') return value dialect = PGDialect +dialect.statement_compiler = PGCompiler +dialect.schemagenerator = PGSchemaGenerator +dialect.schemadropper = PGSchemaDropper +dialect.preparer = 
PGIdentifierPreparer +dialect.defaultrunner = PGDefaultRunner diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 725ea23e2f..f8bea90ebc 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -1,20 +1,31 @@ # sqlite.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import re +import datetime, re, time -from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault -import sqlalchemy.engine.default as default +from sqlalchemy import schema, exceptions, pool, PassiveDefault +from sqlalchemy.engine import default import sqlalchemy.types as sqltypes -import datetime,time, warnings import sqlalchemy.util as util +from sqlalchemy.sql import compiler, functions as sql_functions + + +SELECT_REGEXP = re.compile(r'\s*(?:SELECT|PRAGMA)', re.I | re.UNICODE) - class SLNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + type_ = self.asdecimal and str or float + def process(value): + if value is not None: + return type_(value) + else: + return value + return process + def get_col_spec(self): if self.precision is None: return "NUMERIC" @@ -30,14 +41,21 @@ class SLSmallInteger(sqltypes.Smallinteger): return "SMALLINT" class DateTimeMixin(object): - def convert_bind_param(self, value, dialect): - if value is not None: - if getattr(value, 'microsecond', None) is not None: - return value.strftime(self.__format__ + "." + str(value.microsecond)) + __format__ = "%Y-%m-%d %H:%M:%S" + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, basestring): + # pass string values thru + return value + elif value is not None: + if self.__microsecond__ and getattr(value, 'microsecond', None) is not None: + return value.strftime(self.__format__ + "." 
+ str(value.microsecond)) + else: + return value.strftime(self.__format__) else: - return value.strftime(self.__format__) - else: - return None + return None + return process def _cvt(self, value, dialect): if value is None: @@ -46,40 +64,49 @@ class DateTimeMixin(object): (value, microsecond) = value.split('.') microsecond = int(microsecond) except ValueError: - (value, microsecond) = (value, 0) + microsecond = 0 return time.strptime(value, self.__format__)[0:6] + (microsecond,) class SLDateTime(DateTimeMixin,sqltypes.DateTime): __format__ = "%Y-%m-%d %H:%M:%S" - + __microsecond__ = True + def get_col_spec(self): return "TIMESTAMP" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.datetime(*tup) + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.datetime(*tup) + return process class SLDate(DateTimeMixin, sqltypes.Date): __format__ = "%Y-%m-%d" + __microsecond__ = False def get_col_spec(self): return "DATE" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.date(*tup[0:3]) + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.date(*tup[0:3]) + return process class SLTime(DateTimeMixin, sqltypes.Time): __format__ = "%H:%M:%S" + __microsecond__ = True def get_col_spec(self): return "TIME" - def convert_result_value(self, value, dialect): - tup = self._cvt(value, dialect) - return tup and datetime.time(*tup[3:7]) + def result_processor(self, dialect): + def process(value): + tup = self._cvt(value, dialect) + return tup and datetime.time(*tup[3:7]) + return process -class SLText(sqltypes.TEXT): +class SLText(sqltypes.Text): def get_col_spec(self): return "TEXT" @@ -99,44 +126,54 @@ class SLBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" - def convert_bind_param(self, value, dialect): - if value is None: - return None - return value and 1 or 0 + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value and 1 or 0 + return process - def convert_result_value(self, value, dialect): - if value is None: - return None - return value and True or False + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process colspecs = { - sqltypes.Integer : SLInteger, - sqltypes.Smallinteger : SLSmallInteger, - sqltypes.Numeric : SLNumeric, - sqltypes.Float : SLNumeric, - sqltypes.DateTime : SLDateTime, - sqltypes.Date : SLDate, - sqltypes.Time : SLTime, - sqltypes.String : SLString, - sqltypes.Binary : SLBinary, - sqltypes.Boolean : SLBoolean, - sqltypes.TEXT : SLText, + sqltypes.Binary: SLBinary, + sqltypes.Boolean: SLBoolean, sqltypes.CHAR: SLChar, + sqltypes.Date: SLDate, + sqltypes.DateTime: SLDateTime, + sqltypes.Float: SLNumeric, + sqltypes.Integer: SLInteger, + sqltypes.NCHAR: SLChar, + sqltypes.Numeric: SLNumeric, + sqltypes.Smallinteger: SLSmallInteger, + sqltypes.String: SLString, + sqltypes.Text: SLText, + sqltypes.Time: SLTime, } -pragma_names = { - 'INTEGER' : SLInteger, - 'INT' : SLInteger, - 'SMALLINT' : SLSmallInteger, - 'VARCHAR' : SLString, - 'CHAR' : SLChar, - 'TEXT' : SLText, - 'NUMERIC' : SLNumeric, - 'FLOAT' : SLNumeric, - 'TIMESTAMP' : SLDateTime, - 'DATETIME' : SLDateTime, - 'DATE' : SLDate, - 'BLOB' : SLBinary, +ischema_names = { + 'BLOB': SLBinary, + 'BOOL': SLBoolean, + 'BOOLEAN': 
SLBoolean, + 'CHAR': SLChar, + 'DATE': SLDate, + 'DATETIME': SLDateTime, + 'DECIMAL': SLNumeric, + 'FLOAT': SLNumeric, + 'INT': SLInteger, + 'INTEGER': SLInteger, + 'NUMERIC': SLNumeric, + 'REAL': SLNumeric, + 'SMALLINT': SLSmallInteger, + 'TEXT': SLText, + 'TIME': SLTime, + 'TIMESTAMP': SLDateTime, + 'VARCHAR': SLString, } def descriptor(): @@ -148,27 +185,32 @@ def descriptor(): class SQLiteExecutionContext(default.DefaultExecutionContext): def post_exec(self): - if self.compiled.isinsert: + if self.compiled.isinsert and not self.executemany: if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - def is_select(self): - return re.match(r'SELECT|PRAGMA', self.statement.lstrip(), re.I) is not None - -class SQLiteDialect(ansisql.ANSIDialect): - + def returns_rows_text(self, statement): + return SELECT_REGEXP.match(statement) + +class SQLiteDialect(default.DefaultDialect): + supports_alter = False + supports_unicode_statements = True + default_paramstyle = 'qmark' + def __init__(self, **kwargs): - ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs) + default.DefaultDialect.__init__(self, **kwargs) def vers(num): return tuple([int(x) for x in num.split('.')]) if self.dbapi is not None: sqlite_ver = self.dbapi.version_info if sqlite_ver < (2,1,'3'): - warnings.warn(RuntimeWarning("The installed version of pysqlite2 (%s) is out-dated, and will cause errors in some cases. Version 2.1.3 or greater is recommended." % '.'.join([str(subver) for subver in sqlite_ver]))) - if vers(self.dbapi.sqlite_version) < vers("3.3.13"): - warnings.warn(RuntimeWarning("The installed version of sqlite (%s) is out-dated, and will cause errors in some cases. Version 3.3.13 or greater is recommended." % self.dbapi.sqlite_version)) + util.warn( + ("The installed version of pysqlite2 (%s) is out-dated " + "and will cause errors in some cases. Version 2.1.3 " + "or greater is recommended.") % + '.'.join([str(subver) for subver in sqlite_ver])) self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) - + def dbapi(cls): try: from pysqlite2 import dbapi2 as sqlite @@ -176,29 +218,21 @@ class SQLiteDialect(ansisql.ANSIDialect): try: from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. 
except ImportError: - try: - sqlite = __import__('sqlite') # skip ourselves - except ImportError: - raise e + raise e return sqlite dbapi = classmethod(dbapi) - def compiler(self, statement, bindparams, **kwargs): - return SQLiteCompiler(self, statement, bindparams, **kwargs) - - def schemagenerator(self, *args, **kwargs): - return SQLiteSchemaGenerator(self, *args, **kwargs) - - def schemadropper(self, *args, **kwargs): - return SQLiteSchemaDropper(self, *args, **kwargs) - - def supports_alter(self): - return False - - def preparer(self): - return SQLiteIdentifierPreparer(self) + def server_version_info(self, connection): + return self.dbapi.sqlite_version_info def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exceptions.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,)) filename = url.database or ':memory:' opts = url.query.copy() @@ -213,42 +247,75 @@ class SQLiteDialect(ansisql.ANSIDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def create_execution_context(self, **kwargs): - return SQLiteExecutionContext(self, **kwargs) - - def supports_unicode_statements(self): - return True - - def last_inserted_ids(self): - return self.context.last_inserted_ids + def create_execution_context(self, connection, **kwargs): + return SQLiteExecutionContext(self, connection, **kwargs) def oid_column_name(self, column): return "oid" + def is_disconnect(self, e): + return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) + + def table_names(self, connection, schema): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='table' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + except exceptions.DBAPIError: + raise + s = ("SELECT name FROM sqlite_master " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + def has_table(self, connection, table_name, schema=None): - cursor = connection.execute("PRAGMA table_info(%s)" % - self.identifier_preparer.quote_identifier(table_name), {}) + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + cursor = connection.execute("%stable_info(%s)" % (pragma, qtable)) row = cursor.fetchone() - # consume remaining rows, to work around: http://www.sqlite.org/cvstrac/tktview?tn=1884 - while cursor.fetchone() is not None:pass + # consume remaining rows, to work around + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + while cursor.fetchone() is not None: + pass return (row is not None) def reflecttable(self, connection, table, include_columns): - c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {}) + preparer = self.identifier_preparer + if table.schema is None: + pragma = "PRAGMA " + else: + pragma = "PRAGMA %s." 
% preparer.quote_identifier(table.schema) + qtable = preparer.format_table(table, False) + + c = connection.execute("%stable_info(%s)" % (pragma, qtable)) found_table = False while True: row = c.fetchone() if row is None: break - #print "row! " + repr(row) + found_table = True - (name, type, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5]) + (name, type_, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5]) name = re.sub(r'^\"|\"$', '', name) if include_columns and name not in include_columns: continue - match = re.match(r'(\w+)(\(.*?\))?', type) + match = re.match(r'(\w+)(\(.*?\))?', type_) if match: coltype = match.group(1) args = match.group(2) @@ -256,16 +323,15 @@ class SQLiteDialect(ansisql.ANSIDialect): coltype = "VARCHAR" args = '' - #print "coltype: " + repr(coltype) + " args: " + repr(args) try: - coltype = pragma_names[coltype] + coltype = ischema_names[coltype] except KeyError: - warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name))) - coltype = sqltypes.NULLTYPE - + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, name)) + coltype = sqltypes.NullType + if args is not None: args = re.findall(r'(\d+)', args) - #print "args! " +repr(args) coltype = coltype(*[int(a) for a in args]) colargs= [] @@ -276,7 +342,7 @@ class SQLiteDialect(ansisql.ANSIDialect): if not found_table: raise exceptions.NoSuchTableError(table.name) - c = connection.execute("PRAGMA foreign_key_list(%s)" % self.preparer().format_table(table), {}) + c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)) fks = {} while True: row = c.fetchone() @@ -292,7 +358,6 @@ class SQLiteDialect(ansisql.ANSIDialect): fk = ([],[]) fks[constraint_name] = fk - #print "row! " + repr([key for key in row.keys()]), repr(row) # look up the table based on the given table's engine, not 'self', # since it could be a ProxyEngine remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection) @@ -305,7 +370,7 @@ class SQLiteDialect(ansisql.ANSIDialect): for name, value in fks.iteritems(): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1])) # check for UNIQUE indexes - c = connection.execute("PRAGMA index_list(%s)" % self.preparer().format_table(table), {}) + c = connection.execute("%sindex_list(%s)" % (pragma, qtable)) unique_indexes = [] while True: row = c.fetchone() @@ -315,23 +380,27 @@ class SQLiteDialect(ansisql.ANSIDialect): unique_indexes.append(row[1]) # loop thru unique indexes for one that includes the primary key for idx in unique_indexes: - c = connection.execute("PRAGMA index_info(" + idx + ")", {}) + c = connection.execute("%sindex_info(%s)" % (pragma, idx)) cols = [] while True: row = c.fetchone() if row is None: break cols.append(row[2]) - col = table.columns[row[2]] -class SQLiteCompiler(ansisql.ANSICompiler): - def visit_cast(self, cast): + +class SQLiteCompiler(compiler.DefaultCompiler): + functions = compiler.DefaultCompiler.functions.copy() + functions.update ( + { + sql_functions.now: 'CURRENT_TIMESTAMP' + } + ) + + def visit_cast(self, cast, **kwargs): if self.dialect.supports_cast: return super(SQLiteCompiler, self).visit_cast(cast) else: - if len(self.select_stack): - # not sure if we want to set the typemap here... 
- self.typemap.setdefault("CAST", cast.type) return self.process(cast.clause) def limit_clause(self, select): @@ -350,10 +419,26 @@ class SQLiteCompiler(ansisql.ANSICompiler): # sqlite has no "FOR UPDATE" AFAICT return '' -class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): + def visit_insert(self, insert_stmt): + self.isinsert = True + colparams = self._get_colparams(insert_stmt) + preparer = self.preparer + + if not colparams: + return "INSERT INTO %s DEFAULT VALUES" % ( + (preparer.format_table(insert_stmt.table),)) + else: + return ("INSERT INTO %s (%s) VALUES (%s)" % + (preparer.format_table(insert_stmt.table), + ', '.join([preparer.format_column(c[0]) + for c in colparams]), + ', '.join([c[1] for c in colparams]))) + + +class SQLiteSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -372,12 +457,36 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): # else: # super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint) -class SQLiteSchemaDropper(ansisql.ANSISchemaDropper): +class SQLiteSchemaDropper(compiler.SchemaDropper): pass -class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = util.Set([ + 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', + 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', + 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', + 'conflict', 'constraint', 'create', 'cross', 'current_date', + 'current_time', 'current_timestamp', 'database', 'default', + 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', + 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', + 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', + 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', + 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is', + 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural', + 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer', + 'plan', 'pragma', 'primary', 'query', 'raise', 'references', + 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback', + 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', + 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', + 'vacuum', 'values', 'view', 'virtual', 'when', 'where', + ]) + def __init__(self, dialect): - super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True) + super(SQLiteIdentifierPreparer, self).__init__(dialect) dialect = SQLiteDialect dialect.poolclass = pool.SingletonThreadPool +dialect.statement_compiler = SQLiteCompiler +dialect.schemagenerator = SQLiteSchemaGenerator +dialect.schemadropper = SQLiteSchemaDropper +dialect.preparer = SQLiteIdentifierPreparer diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py new file mode 100644 index 0000000000..2551e90c53 --- /dev/null +++ b/lib/sqlalchemy/databases/sybase.py @@ -0,0 +1,876 @@ +# sybase.py +# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch +# Coding: Alexander Houben alexander.houben@thor-solutions.ch +# +# 
This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Sybase database backend. + +Known issues / TODO: + + * Uses the mx.ODBC driver from egenix (version 2.1.0) + * The current version of sqlalchemy.databases.sybase only supports + mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need + some development) + * Support for pyodbc has been built in but is not yet complete (needs + further development) + * Results of running tests/alltests.py: + Ran 934 tests in 287.032s + FAILED (failures=3, errors=1) + * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) +""" + +import datetime, operator + +from sqlalchemy import util, sql, schema, exceptions +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import MetaData, Table, Column +from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey + + +__all__ = [ + 'SybaseTypeError' + 'SybaseNumeric', 'SybaseFloat', 'SybaseInteger', 'SybaseBigInteger', + 'SybaseTinyInteger', 'SybaseSmallInteger', + 'SybaseDateTime_mxodbc', 'SybaseDateTime_pyodbc', + 'SybaseDate_mxodbc', 'SybaseDate_pyodbc', + 'SybaseTime_mxodbc', 'SybaseTime_pyodbc', + 'SybaseText', 'SybaseString', 'SybaseChar', 'SybaseBinary', + 'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney', + 'SybaseUniqueIdentifier', + ] + + +RESERVED_WORDS = util.Set([ + "add", "all", "alter", "and", + "any", "as", "asc", "backup", + "begin", "between", "bigint", "binary", + "bit", "bottom", "break", "by", + "call", "capability", "cascade", "case", + "cast", "char", "char_convert", "character", + "check", "checkpoint", "close", "comment", + "commit", "connect", "constraint", "contains", + "continue", "convert", "create", "cross", + "cube", "current", "current_timestamp", "current_user", + "cursor", "date", "dbspace", "deallocate", + "dec", "decimal", "declare", "default", + "delete", "deleting", "desc", "distinct", + "do", "double", "drop", "dynamic", + "else", "elseif", "encrypted", "end", + "endif", "escape", "except", "exception", + "exec", "execute", "existing", "exists", + "externlogin", "fetch", "first", "float", + "for", "force", "foreign", "forward", + "from", "full", "goto", "grant", + "group", "having", "holdlock", "identified", + "if", "in", "index", "index_lparen", + "inner", "inout", "insensitive", "insert", + "inserting", "install", "instead", "int", + "integer", "integrated", "intersect", "into", + "iq", "is", "isolation", "join", + "key", "lateral", "left", "like", + "lock", "login", "long", "match", + "membership", "message", "mode", "modify", + "natural", "new", "no", "noholdlock", + "not", "notify", "null", "numeric", + "of", "off", "on", "open", + "option", "options", "or", "order", + "others", "out", "outer", "over", + "passthrough", "precision", "prepare", "primary", + "print", "privileges", "proc", "procedure", + "publication", "raiserror", "readtext", "real", + "reference", "references", "release", "remote", + "remove", "rename", "reorganize", "resource", + "restore", "restrict", "return", "revoke", + "right", "rollback", "rollup", "save", + "savepoint", "scroll", "select", "sensitive", + "session", "set", "setuser", "share", + "smallint", "some", "sqlcode", "sqlstate", + "start", "stop", "subtrans", "subtransaction", + "synchronize", "syntax_error", "table", "temporary", + "then", "time", "timestamp", 
"tinyint", + "to", "top", "tran", "trigger", + "truncate", "tsequal", "unbounded", "union", + "unique", "unknown", "unsigned", "update", + "updating", "user", "using", "validate", + "values", "varbinary", "varchar", "variable", + "varying", "view", "wait", "waitfor", + "when", "where", "while", "window", + "with", "with_cube", "with_lparen", "with_rollup", + "within", "work", "writetext", + ]) + +ischema = MetaData() + +tables = Table("SYSTABLE", ischema, + Column("table_id", Integer, primary_key=True), + Column("file_id", SMALLINT), + Column("table_name", CHAR(128)), + Column("table_type", CHAR(10)), + Column("creator", Integer), + #schema="information_schema" + ) + +domains = Table("SYSDOMAIN", ischema, + Column("domain_id", Integer, primary_key=True), + Column("domain_name", CHAR(128)), + Column("type_id", SMALLINT), + Column("precision", SMALLINT, quote=True), + #schema="information_schema" + ) + +columns = Table("SYSCOLUMN", ischema, + Column("column_id", Integer, primary_key=True), + Column("table_id", Integer, ForeignKey(tables.c.table_id)), + Column("pkey", CHAR(1)), + Column("column_name", CHAR(128)), + Column("nulls", CHAR(1)), + Column("width", SMALLINT), + Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)), + # FIXME: should be mx.BIGINT + Column("max_identity", Integer), + # FIXME: should be mx.ODBC.Windows.LONGVARCHAR + Column("default", String), + Column("scale", Integer), + #schema="information_schema" + ) + +foreignkeys = Table("SYSFOREIGNKEY", ischema, + Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True), + Column("foreign_key_id", SMALLINT, primary_key=True), + Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)), + #schema="information_schema" + ) +fkcols = Table("SYSFKCOL", ischema, + Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True), + Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True), + Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True), + Column("primary_column_id", Integer), + #schema="information_schema" + ) + +class SybaseTypeError(sqltypes.TypeEngine): + def result_processor(self, dialect): + return None + + def bind_processor(self, dialect): + def process(value): + raise exceptions.NotSupportedError("Data type not supported", [value]) + return process + + def get_col_spec(self): + raise exceptions.NotSupportedError("Data type not supported") + +class SybaseNumeric(sqltypes.Numeric): + def get_col_spec(self): + if self.length is None: + if self.precision is None: + return "NUMERIC" + else: + return "NUMERIC(%(precision)s)" % {'precision' : self.precision} + else: + return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + +class SybaseFloat(sqltypes.FLOAT, SybaseNumeric): + def __init__(self, precision = 10, asdecimal = False, length = 2, **kwargs): + super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs) + self.length = length + + def get_col_spec(self): + # if asdecimal is True, handle same way as SybaseNumeric + if self.asdecimal: + return SybaseNumeric.get_col_spec(self) + if self.precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': self.precision} + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return float(value) + if self.asdecimal: + return SybaseNumeric.result_processor(self, dialect) + return process + +class 
SybaseInteger(sqltypes.Integer): + def get_col_spec(self): + return "INTEGER" + +class SybaseBigInteger(SybaseInteger): + def get_col_spec(self): + return "BIGINT" + +class SybaseTinyInteger(SybaseInteger): + def get_col_spec(self): + return "TINYINT" + +class SybaseSmallInteger(SybaseInteger): + def get_col_spec(self): + return "SMALLINT" + +class SybaseDateTime_mxodbc(sqltypes.DateTime): + def __init__(self, *a, **kw): + super(SybaseDateTime_mxodbc, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + +class SybaseDateTime_pyodbc(sqltypes.DateTime): + def __init__(self, *a, **kw): + super(SybaseDateTime_pyodbc, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + # Convert the datetime.datetime back to datetime.time + return value + return process + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value + return process + +class SybaseDate_mxodbc(sqltypes.Date): + def __init__(self, *a, **kw): + super(SybaseDate_mxodbc, self).__init__(False) + + def get_col_spec(self): + return "DATE" + +class SybaseDate_pyodbc(sqltypes.Date): + def __init__(self, *a, **kw): + super(SybaseDate_pyodbc, self).__init__(False) + + def get_col_spec(self): + return "DATE" + +class SybaseTime_mxodbc(sqltypes.Time): + def __init__(self, *a, **kw): + super(SybaseTime_mxodbc, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + # Convert the datetime.datetime back to datetime.time + return datetime.time(value.hour, value.minute, value.second, value.microsecond) + return process + +class SybaseTime_pyodbc(sqltypes.Time): + def __init__(self, *a, **kw): + super(SybaseTime_pyodbc, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + # Convert the datetime.datetime back to datetime.time + return datetime.time(value.hour, value.minute, value.second, value.microsecond) + return process + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return datetime.datetime(1970, 1, 1, value.hour, value.minute, value.second, value.microsecond) + return process + +class SybaseText(sqltypes.Text): + def get_col_spec(self): + return "TEXT" + +class SybaseString(sqltypes.String): + def get_col_spec(self): + return "VARCHAR(%(length)s)" % {'length' : self.length} + +class SybaseChar(sqltypes.CHAR): + def get_col_spec(self): + return "CHAR(%(length)s)" % {'length' : self.length} + +class SybaseBinary(sqltypes.Binary): + def get_col_spec(self): + return "IMAGE" + +class SybaseBoolean(sqltypes.Boolean): + def get_col_spec(self): + return "BIT" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class SybaseTimeStamp(sqltypes.TIMESTAMP): + def get_col_spec(self): + return "TIMESTAMP" + +class SybaseMoney(sqltypes.TypeEngine): + def get_col_spec(self): + return "MONEY" + +class SybaseSmallMoney(SybaseMoney): + def get_col_spec(self): + return "SMALLMONEY" + +class 
SybaseUniqueIdentifier(sqltypes.TypeEngine): + def get_col_spec(self): + return "UNIQUEIDENTIFIER" + +def descriptor(): + return {'name':'sybase', + 'description':'SybaseSQL', + 'arguments':[ + ('user',"Database Username",None), + ('password',"Database Password",None), + ('db',"Database Name",None), + ('host',"Hostname", None), + ]} + +class SybaseSQLExecutionContext(default.DefaultExecutionContext): + pass + +class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext): + + def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): + super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters) + + def pre_exec(self): + super(SybaseSQLExecutionContext_mxodbc, self).pre_exec() + + def post_exec(self): + if self.compiled.isinsert: + table = self.compiled.statement.table + # get the inserted values of the primary key + + # get any sequence IDs first (using @@identity) + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + lastrowid = int(row[0]) + if lastrowid > 0: + # an IDENTITY was inserted, fetch it + # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! + if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: + self._last_inserted_ids = [lastrowid] + else: + self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] + super(SybaseSQLExecutionContext_mxodbc, self).post_exec() + +class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext): + def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): + super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters) + + def pre_exec(self): + super(SybaseSQLExecutionContext_pyodbc, self).pre_exec() + + def post_exec(self): + if self.compiled.isinsert: + table = self.compiled.statement.table + # get the inserted values of the primary key + + # get any sequence IDs first (using @@identity) + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + lastrowid = int(row[0]) + if lastrowid > 0: + # an IDENTITY was inserted, fetch it + # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! 
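[editor's note: both execution contexts above recover the newly generated primary key by selecting @@identity right after the INSERT. Reduced to plain DB-API calls, the pattern is roughly the sketch below; the helper name fetch_identity is hypothetical.]

    def fetch_identity(cursor):
        # ask the server for the IDENTITY value generated by the
        # previous INSERT, as post_exec() does above
        cursor.execute("SELECT @@identity AS lastrowid")
        lastrowid = int(cursor.fetchone()[0])
        if lastrowid > 0:
            # an IDENTITY value was generated for this row
            return [lastrowid]
        return None
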
+ if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: + self._last_inserted_ids = [lastrowid] + else: + self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] + super(SybaseSQLExecutionContext_pyodbc, self).post_exec() + +class SybaseSQLDialect(default.DefaultDialect): + colspecs = { + # FIXME: unicode support + #sqltypes.Unicode : SybaseUnicode, + sqltypes.Integer : SybaseInteger, + sqltypes.SmallInteger : SybaseSmallInteger, + sqltypes.Numeric : SybaseNumeric, + sqltypes.Float : SybaseFloat, + sqltypes.String : SybaseString, + sqltypes.Binary : SybaseBinary, + sqltypes.Boolean : SybaseBoolean, + sqltypes.Text : SybaseText, + sqltypes.CHAR : SybaseChar, + sqltypes.TIMESTAMP : SybaseTimeStamp, + sqltypes.FLOAT : SybaseFloat, + } + + ischema_names = { + 'integer' : SybaseInteger, + 'unsigned int' : SybaseInteger, + 'unsigned smallint' : SybaseInteger, + 'unsigned bigint' : SybaseInteger, + 'bigint': SybaseBigInteger, + 'smallint' : SybaseSmallInteger, + 'tinyint' : SybaseTinyInteger, + 'varchar' : SybaseString, + 'long varchar' : SybaseText, + 'char' : SybaseChar, + 'decimal' : SybaseNumeric, + 'numeric' : SybaseNumeric, + 'float' : SybaseFloat, + 'double' : SybaseFloat, + 'binary' : SybaseBinary, + 'long binary' : SybaseBinary, + 'varbinary' : SybaseBinary, + 'bit': SybaseBoolean, + 'image' : SybaseBinary, + 'timestamp': SybaseTimeStamp, + 'money': SybaseMoney, + 'smallmoney': SybaseSmallMoney, + 'uniqueidentifier': SybaseUniqueIdentifier, + + 'java.lang.Object' : SybaseTypeError, + 'java serialization' : SybaseTypeError, + } + + # Sybase backend peculiarities + supports_unicode_statements = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + def __new__(cls, dbapi=None, *args, **kwargs): + if cls != SybaseSQLDialect: + return super(SybaseSQLDialect, cls).__new__(cls, *args, **kwargs) + if dbapi: + print dbapi.__name__ + dialect = dialect_mapping.get(dbapi.__name__) + return dialect(*args, **kwargs) + else: + return object.__new__(cls, *args, **kwargs) + + def __init__(self, **params): + super(SybaseSQLDialect, self).__init__(**params) + self.text_as_varchar = False + # FIXME: what is the default schema for sybase connections (DBA?) ? + self.set_default_schema_name("dba") + + def dbapi(cls, module_name=None): + if module_name: + try: + dialect_cls = dialect_mapping[module_name] + return dialect_cls.import_dbapi() + except KeyError: + raise exceptions.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name) + else: + for dialect_cls in dialect_mapping.values(): + try: + return dialect_cls.import_dbapi() + except ImportError, e: + pass + else: + raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc') + dbapi = classmethod(dbapi) + + def create_execution_context(self, *args, **kwargs): + return SybaseSQLExecutionContext(self, *args, **kwargs) + + def type_descriptor(self, typeobj): + newobj = sqltypes.adapt_type(typeobj, self.colspecs) + return newobj + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def get_default_schema_name(self, connection): + return self.schema_name + + def set_default_schema_name(self, schema_name): + self.schema_name = schema_name + + def do_execute(self, cursor, statement, params, **kwargs): + params = tuple(params) + super(SybaseSQLDialect, self).do_execute(cursor, statement, params, **kwargs) + + # FIXME: remove ? 
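[editor's note: a minimal sketch of how the colspecs dictionary above is consumed. type_descriptor() hands it to sqltypes.adapt_type() to swap a generic type for its Sybase-specific subclass; this assumes the new module is importable as sqlalchemy.databases.sybase once this commit is applied.]

    from sqlalchemy import types as sqltypes
    from sqlalchemy.databases import sybase

    generic = sqltypes.String(30)
    specific = sqltypes.adapt_type(generic, sybase.SybaseSQLDialect.colspecs)
    # expected: an instance of SybaseString rendering VARCHAR(30)
    print specific.get_col_spec()
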
+ def _execute(self, c, statement, parameters): + try: + if parameters == {}: + parameters = () + c.execute(statement, parameters) + self.context.rowcount = c.rowcount + c.DBPROP_COMMITPRESERVE = "Y" + except Exception, e: + raise exceptions.DBAPIError.instance(statement, parameters, e) + + def table_names(self, connection, schema): + """Ignore the schema and the charset for now.""" + s = sql.select([tables.c.table_name], + sql.not_(tables.c.table_name.like("SYS%")) and + tables.c.creator >= 100 + ) + rp = connection.execute(s) + return [row[0] for row in rp.fetchall()] + + def has_table(self, connection, tablename, schema=None): + # FIXME: ignore schemas for sybase + s = sql.select([tables.c.table_name], tables.c.table_name == tablename) + + c = connection.execute(s) + row = c.fetchone() + print "has_table: " + tablename + ": " + str(bool(row is not None)) + return row is not None + + def reflecttable(self, connection, table, include_columns): + # Get base columns + if table.schema is not None: + current_schema = table.schema + else: + current_schema = self.get_default_schema_name(connection) + + s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) + + c = connection.execute(s) + found_table = False + # makes sure we append the columns in the correct order + while True: + row = c.fetchone() + if row is None: + break + found_table = True + (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = ( + row[columns.c.column_name], + row[domains.c.domain_name], + row[columns.c.nulls] == 'Y', + row[columns.c.width], + row[domains.c.precision], + row[columns.c.scale], + row[columns.c.default], + row[columns.c.pkey] == 'Y', + row[columns.c.max_identity], + row[tables.c.table_id], + row[columns.c.column_id], + ) + if include_columns and name not in include_columns: + continue + + # FIXME: else problems with SybaseBinary(size) + if numericscale == 0: + numericscale = None + + args = [] + for a in (charlen, numericprec, numericscale): + if a is not None: + args.append(a) + coltype = self.ischema_names.get(type, None) + if coltype == SybaseString and charlen == -1: + coltype = SybaseText() + else: + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % + (type, name)) + coltype = sqltypes.NULLTYPE + coltype = coltype(*args) + colargs= [] + if default is not None: + colargs.append(schema.PassiveDefault(sql.text(default))) + + # any sequences ? + col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs) + if int(max_identity) > 0: + col.sequence = schema.Sequence(name + '_identity') + col.sequence.start = int(max_identity) + col.sequence.increment = 1 + + # append the column + table.append_column(col) + + # any foreign key constraint for this table ? 
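[editor's note: reflecttable() above builds its column query from the SYSCOLUMN / SYSTABLE / SYSDOMAIN Table objects declared earlier in this module. The statement it issues is approximately the sketch below; the table name 'users' is illustrative and the import path assumes this commit is applied.]

    from sqlalchemy import sql
    from sqlalchemy.databases.sybase import columns, domains, tables

    # roughly the column-reflection statement issued by reflecttable()
    s = sql.select([columns, domains],
                   tables.c.table_name == 'users',
                   from_obj=[columns.join(tables).join(domains)],
                   order_by=[columns.c.column_id])
    print s
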
+ # note: no multi-column foreign keys are considered + s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name } + c = connection.execute(s) + foreignKeys = {} + while True: + row = c.fetchone() + if row is None: + break + (foreign_table, foreign_column, primary_table, primary_column) = ( + row[0], row[1], row[2], row[3], + ) + if not primary_table in foreignKeys.keys(): + foreignKeys[primary_table] = [['%s'%(foreign_column)], ['%s.%s'%(primary_table,primary_column)]] + else: + foreignKeys[primary_table][0].append('%s'%(foreign_column)) + foreignKeys[primary_table][1].append('%s.%s'%(primary_table,primary_column)) + for primary_table in foreignKeys.keys(): + #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) + table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1])) + + if not found_table: + raise exceptions.NoSuchTableError(table.name) + + +class SybaseSQLDialect_mxodbc(SybaseSQLDialect): + def __init__(self, **params): + super(SybaseSQLDialect_mxodbc, self).__init__(**params) + + self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()} + + def import_dbapi(cls): + #import mx.ODBC.Windows as module + import mxODBC as module + return module + import_dbapi = classmethod(import_dbapi) + + colspecs = SybaseSQLDialect.colspecs.copy() + colspecs[sqltypes.Time] = SybaseTime_mxodbc + colspecs[sqltypes.Date] = SybaseDate_mxodbc + colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc + + ischema_names = SybaseSQLDialect.ischema_names.copy() + ischema_names['time'] = SybaseTime_mxodbc + ischema_names['date'] = SybaseDate_mxodbc + ischema_names['datetime'] = SybaseDateTime_mxodbc + ischema_names['smalldatetime'] = SybaseDateTime_mxodbc + + def is_disconnect(self, e): + # FIXME: optimize + #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) + #return True + return False + + def create_execution_context(self, *args, **kwargs): + return SybaseSQLExecutionContext_mxodbc(self, *args, **kwargs) + + def do_execute(self, cursor, statement, parameters, context=None, **kwargs): + super(SybaseSQLDialect_mxodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) + + def create_connect_args(self, url): + '''Return a tuple of *args,**kwargs''' + # FIXME: handle mx.odbc.Windows proprietary args + opts = url.translate_connect_args(username='user') + opts.update(url.query) + argsDict = {} + argsDict['user'] = opts['user'] + argsDict['password'] = opts['password'] + connArgs = [[opts['dsn']], argsDict] + return connArgs + + +class SybaseSQLDialect_pyodbc(SybaseSQLDialect): + def __init__(self, **params): + super(SybaseSQLDialect_pyodbc, self).__init__(**params) + self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()} + + def import_dbapi(cls): + import mypyodbc as module + return module + import_dbapi = classmethod(import_dbapi) + + colspecs = SybaseSQLDialect.colspecs.copy() + colspecs[sqltypes.Time] = SybaseTime_pyodbc + colspecs[sqltypes.Date] = SybaseDate_pyodbc + colspecs[sqltypes.DateTime] = 
SybaseDateTime_pyodbc + + ischema_names = SybaseSQLDialect.ischema_names.copy() + ischema_names['time'] = SybaseTime_pyodbc + ischema_names['date'] = SybaseDate_pyodbc + ischema_names['datetime'] = SybaseDateTime_pyodbc + ischema_names['smalldatetime'] = SybaseDateTime_pyodbc + + def is_disconnect(self, e): + # FIXME: optimize + #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) + #return True + return False + + def create_execution_context(self, *args, **kwargs): + return SybaseSQLExecutionContext_pyodbc(self, *args, **kwargs) + + def do_execute(self, cursor, statement, parameters, context=None, **kwargs): + super(SybaseSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) + + def create_connect_args(self, url): + '''Return a tuple of *args,**kwargs''' + # FIXME: handle pyodbc proprietary args + opts = url.translate_connect_args(username='user') + opts.update(url.query) + + self.autocommit = False + if 'autocommit' in opts: + self.autocommit = bool(int(opts.pop('autocommit'))) + + argsDict = {} + argsDict['UID'] = opts['user'] + argsDict['PWD'] = opts['password'] + argsDict['DSN'] = opts['dsn'] + connArgs = [[';'.join(["%s=%s"%(key, argsDict[key]) for key in argsDict])], {'autocommit' : self.autocommit}] + return connArgs + + +dialect_mapping = { + 'sqlalchemy.databases.mxODBC' : SybaseSQLDialect_mxodbc, +# 'pyodbc' : SybaseSQLDialect_pyodbc, + } + + +class SybaseSQLCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() + operators.update({ + sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), + }) + + def bindparam_string(self, name): + res = super(SybaseSQLCompiler, self).bindparam_string(name) + if name.lower().startswith('literal'): + res = 'STRING(%s)'%res + return res + + def get_select_precolumns(self, select): + s = select._distinct and "DISTINCT " or "" + if select._limit: + #if select._limit == 1: + #s += "FIRST " + #else: + #s += "TOP %s " % (select._limit,) + s += "TOP %s " % (select._limit,) + if select._offset: + if not select._limit: + # FIXME: sybase doesn't allow an offset without a limit + # so use a huge value for TOP here + s += "TOP 1000000 " + s += "START AT %s " % (select._offset+1,) + return s + + def limit_clause(self, select): + # Limit in sybase is after the select keyword + return "" + + def visit_binary(self, binary): + """Move bind parameters to the right-hand side of an operator, where possible.""" + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(SybaseSQLCompiler, self).visit_binary(binary) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression._Function): + return column.label(None) + else: + return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) + + function_rewrites = {'current_date': 'getdate', + } + def visit_function(self, func): + func.name = self.function_rewrites.get(func.name, func.name) + res = super(SybaseSQLCompiler, self).visit_function(func) + if func.name.lower() == 'getdate': + # apply CAST operator + # FIXME: what about _pyodbc ? 
+ cast = expression._Cast(func, SybaseDate_mxodbc) + # infinite recursion + # res = self.visit_cast(cast) + res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) + return res + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery(select) or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class SybaseSQLSchemaGenerator(compiler.SchemaGenerator): + def get_column_specification(self, column, **kwargs): + + colspec = self.preparer.format_column(column) + + if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ + column.autoincrement and isinstance(column.type, sqltypes.Integer): + if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + column.sequence = schema.Sequence(column.name + '_seq') + + if hasattr(column, 'sequence'): + column.table.has_sequence = column + #colspec += " numeric(30,0) IDENTITY" + colspec += " Integer IDENTITY" + else: + colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + + if not column.nullable: + colspec += " NOT NULL" + + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + +class SybaseSQLSchemaDropper(compiler.SchemaDropper): + def visit_index(self, index): + self.append("\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(index.table.name), + self.preparer.quote_identifier(index.name) + )) + self.execute() + + +class SybaseSQLDefaultRunner(base.DefaultRunner): + pass + + +class SybaseSQLIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(SybaseSQLIdentifierPreparer, self).__init__(dialect) + + def _escape_identifier(self, value): + #TODO: determin SybaseSQL's escapeing rules + return value + + def _fold_identifier_case(self, value): + #TODO: determin SybaseSQL's case folding rules + return value + + +dialect = SybaseSQLDialect +dialect.statement_compiler = SybaseSQLCompiler +dialect.schemagenerator = SybaseSQLSchemaGenerator +dialect.schemadropper = SybaseSQLSchemaDropper +dialect.preparer = SybaseSQLIdentifierPreparer +dialect.defaultrunner = SybaseSQLDefaultRunner diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 50d03ea91d..eab8b3c0b4 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -1,80 +1,88 @@ # engine/__init__.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""defines the basic components used to interface DBAPI modules with -higher-level statement-construction, connection-management, -execution and result contexts. The primary "entry point" class into -this package is the Engine. 
- -The package is represented among several individual modules, including: - - base.py - Defines interface classes and some implementation classes - which comprise the basic components used to interface between - a DBAPI, constructed and plain-text statements, - connections, transactions, and results. - - default.py - Contains default implementations of some of the components - defined in base.py. All current database dialects use the - classes in default.py as base classes for their own database-specific - implementations. - - strategies.py - the mechanics of constructing ``Engine`` objects are represented here. - Defines the ``EngineStrategy`` class which represents how to go from - arguments specified to the ``create_engine()`` function, to a fully - constructed ``Engine``, including initialization of connection pooling, - dialects, and specific subclasses of ``Engine``. - - threadlocal.py - the ``TLEngine`` class is defined here, which is a subclass of the generic - ``Engine`` and tracks ``Connection`` and ``Transaction`` objects against - the identity of the current thread. This allows certain programming patterns - based around the concept of a "thread-local connection" to be possible. The - ``TLEngine`` is created by using the "threadlocal" engine strategy in - conjunction with the ``create_engine()`` function. - - url.py - Defines the ``URL`` class which represents the individual components of a - string URL passed to ``create_engine()``. Also defines a basic module-loading - strategy for the dialect specifier within a URL. - +"""SQL connections, SQL execution and high-level DB-API interface. + +The engine package defines the basic components used to interface +DB-API modules with higher-level statement construction, +connection-management, execution and result contexts. The primary +"entry point" class into this package is the Engine and it's public +constructor ``create_engine()``. + +This package includes: + +base.py + Defines interface classes and some implementation classes which + comprise the basic components used to interface between a DB-API, + constructed and plain-text statements, connections, transactions, + and results. + +default.py + Contains default implementations of some of the components defined + in base.py. All current database dialects use the classes in + default.py as base classes for their own database-specific + implementations. + +strategies.py + The mechanics of constructing ``Engine`` objects are represented + here. Defines the ``EngineStrategy`` class which represents how + to go from arguments specified to the ``create_engine()`` + function, to a fully constructed ``Engine``, including + initialization of connection pooling, dialects, and specific + subclasses of ``Engine``. + +threadlocal.py + The ``TLEngine`` class is defined here, which is a subclass of + the generic ``Engine`` and tracks ``Connection`` and + ``Transaction`` objects against the identity of the current + thread. This allows certain programming patterns based around + the concept of a "thread-local connection" to be possible. + The ``TLEngine`` is created by using the "threadlocal" engine + strategy in conjunction with the ``create_engine()`` function. + +url.py + Defines the ``URL`` class which represents the individual + components of a string URL passed to ``create_engine()``. Also + defines a basic module-loading strategy for the dialect specifier + within a URL. 
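[editor's note: a minimal sketch of how the strategies described above are selected. The "plain"/"threadlocal" choice is made purely through the strategy keyword of create_engine(); the SQLite URL is illustrative.]

    from sqlalchemy import create_engine

    # default "plain" strategy
    engine = create_engine('sqlite:///:memory:')

    # "threadlocal" strategy: connections and transactions are tracked
    # against the current thread via the TLEngine described above
    tl_engine = create_engine('sqlite:///:memory:', strategy='threadlocal')
    conn = tl_engine.contextual_connect()
    conn.close()
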
""" -from sqlalchemy import databases -from sqlalchemy.engine.base import * +import sqlalchemy.databases +from sqlalchemy.engine.base import Dialect, ExecutionContext, Compiled, \ + Connectable, Connection, Transaction, RootTransaction, \ + NestedTransaction, TwoPhaseTransaction, Engine, RowProxy, \ + BufferedColumnRow, ResultProxy, BufferedRowResultProxy, \ + BufferedColumnResultProxy, SchemaIterator, DefaultRunner from sqlalchemy.engine import strategies +from sqlalchemy import util -def engine_descriptors(): - """Provide a listing of all the database implementations supported. - - This data is provided as a list of dictionaries, where each - dictionary contains the following key/value pairs: - name - the name of the engine, suitable for use in the create_engine function +__all__ = [ + 'engine_descriptors', 'create_engine', 'engine_from_config', + 'Dialect', 'ExecutionContext', 'Compiled', 'Connectable', + 'Connection', 'Transaction', 'RootTransaction', 'NestedTransaction', + 'TwoPhaseTransaction', 'Engine', 'RowProxy', 'BufferedColumnRow', + 'ResultProxy', 'BufferedRowResultProxy', 'BufferedColumnResultProxy', + 'SchemaIterator', 'DefaultRunner', + ] - description - a plain description of the engine. - - arguments - a dictionary describing the name and description of each - parameter used to connect to this engine's underlying DBAPI. +def engine_descriptors(): + """Provide a listing of all the database implementations supported. + + This method will be removed in 0.5. - This function is meant for usage in automated configuration tools - that wish to query the user for database and connection - information. """ - result = [] for module in sqlalchemy.databases.__all__: - module = getattr(__import__('sqlalchemy.databases.%s' % module).databases, module) + module = getattr( + __import__('sqlalchemy.databases.%s' % module).databases, module) result.append(module.descriptor()) return result +engine_descriptors = util.deprecated()(engine_descriptors) + default_strategy = 'plain' def create_engine(*args, **kwargs): @@ -85,10 +93,11 @@ def create_engine(*args, **kwargs): dialect and connection arguments, with additional keyword arguments sent as options to the dialect and resulting Engine. - The URL is a string in the form - ``dialect://user:password@host/dbname[?key=value..]``, where + The URL is a string in the form + ``dialect://user:password@host/dbname[?key=value..]``, where ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgres``, - etc. Alternatively, the URL can be an instance of ``sqlalchemy.engine.url.URL``. + etc. Alternatively, the URL can be an instance of + ``sqlalchemy.engine.url.URL``. `**kwargs` represents options to be sent to the Engine itself as well as the components of the Engine, including the Dialect, the @@ -96,17 +105,17 @@ def create_engine(*args, **kwargs): follows: poolclass - a subclass of ``sqlalchemy.pool.Pool`` which will be used to + a subclass of ``sqlalchemy.pool.Pool`` which will be used to instantiate a connection pool. - + pool an instance of ``sqlalchemy.pool.DBProxy`` or ``sqlalchemy.pool.Pool`` to be used as the underlying source for - connections (DBProxy/Pool is described in the previous - section). This argument supercedes "poolclass". + connections (DBProxy/Pool is described in the previous section). + This argument supercedes "poolclass". 
echo - Defaults to False: if True, the Engine will log all statements + defaults to False: if True, the Engine will log all statements as well as a repr() of their parameter lists to the engines logger, which defaults to ``sys.stdout``. A Engine instances' `echo` data member can be modified at any time to turn logging @@ -114,40 +123,74 @@ def create_engine(*args, **kwargs): printed to the standard output as well. logger - Defaults to None: a file-like object where logging output can be + defaults to None: a file-like object where logging output can be sent, if `echo` is set to True. This defaults to ``sys.stdout``. encoding - Defaults to 'utf-8': the encoding to be used when + defaults to 'utf-8': the encoding to be used when encoding/decoding Unicode strings. convert_unicode - Defaults to False: true if unicode conversion should be applied + defaults to False: true if unicode conversion should be applied to all str types. module - Defaults to None: this is a - reference to a DBAPI2 module to be used instead of the engine's - default module. For Postgres, the default is psycopg2, or - psycopg1 if 2 cannot be found. For Oracle, its cx_Oracle. For - mysql, MySQLdb. + defaults to None: this is a reference to a DB-API 2.0 module to + be used instead of the dialect's default module. strategy - allows alternate Engine implementations to take effect. - Current implementations include ``plain`` and ``threadlocal``. - The default used by this function is ``plain``. - - ``plain`` provides support for a Connection object which can be used - to execute SQL queries with a specific underlying DBAPI connection. - - ``threadlocal`` is similar to ``plain`` except that it adds support - for a thread-local connection and transaction context, which - allows a group of engine operations to participate using the same - underlying connection and transaction without the need for explicitly - passing a single Connection. + allows alternate Engine implementations to take effect. Current + implementations include ``plain`` and ``threadlocal``. The + default used by this function is ``plain``. + + ``plain`` provides support for a Connection object which can be + used to execute SQL queries with a specific underlying DB-API + connection. + + ``threadlocal`` is similar to ``plain`` except that it adds + support for a thread-local connection and transaction context, + which allows a group of engine operations to participate using + the same underlying connection and transaction without the need + for explicitly passing a single Connection. """ strategy = kwargs.pop('strategy', default_strategy) strategy = strategies.strategies[strategy] return strategy.create(*args, **kwargs) + +def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): + """Create a new Engine instance using a configuration dictionary. + + The dictionary is typically produced from a config file where keys + are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The + 'prefix' argument indicates the prefix to be searched for. + + A select set of keyword arguments will be "coerced" to their + expected type based on string values. In a future release, this + functionality will be expanded and include dialect-specific + arguments. 
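[editor's note: a minimal sketch of the engine_from_config() function introduced here; it reads prefixed keys out of a plain dictionary, with string values coerced by the _coerce_config() helper that follows. The URL and option values are illustrative.]

    from sqlalchemy.engine import engine_from_config

    configuration = {
        'sqlalchemy.url': 'sqlite:///:memory:',   # illustrative URL
        'sqlalchemy.echo': 'true',                # coerced from string
        'sqlalchemy.pool_recycle': '3600',        # coerced to an int
    }
    engine = engine_from_config(configuration, prefix='sqlalchemy.')
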
+ """ + + opts = _coerce_config(configuration, prefix) + opts.update(kwargs) + url = opts.pop('url') + return create_engine(url, **opts) + +def _coerce_config(configuration, prefix): + """Convert configuration values to expected types.""" + + options = dict([(key[len(prefix):], configuration[key]) + for key in configuration if key.startswith(prefix)]) + for option, type_ in ( + ('convert_unicode', bool), + ('pool_timeout', int), + ('echo', bool), + ('echo_pool', bool), + ('pool_recycle', int), + ('pool_size', int), + ('max_overflow', int), + ('pool_threadlocal', bool), + ): + util.coerce_kw_type(options, option, type_) + return options diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index fc4433a47c..583a027638 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1,159 +1,160 @@ # engine/base.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""defines the basic components used to interface DBAPI modules with -higher-level statement-construction, connection-management, -execution and result contexts.""" -from sqlalchemy import exceptions, sql, schema, util, types, logging -import StringIO, sys, re, random +"""Basic components for SQL execution and interfacing with DB-API. + +Defines the basic components used to interface DB-API modules with +higher-level statement-construction, connection-management, execution +and result contexts. +""" + +import inspect, StringIO, sys +from sqlalchemy import exceptions, schema, util, types, logging +from sqlalchemy.sql import expression class Dialect(object): - """Define the behavior of a specific database/DBAPI. + """Define the behavior of a specific database and DB-API combination. - Any aspect of metadata definition, SQL query generation, execution, - result-set handling, or anything else which varies between - databases is defined under the general category of the Dialect. - The Dialect acts as a factory for other database-specific object - implementations including ExecutionContext, Compiled, - DefaultGenerator, and TypeEngine. + Any aspect of metadata definition, SQL query generation, + execution, result-set handling, or anything else which varies + between databases is defined under the general category of the + Dialect. The Dialect acts as a factory for other + database-specific object implementations including + ExecutionContext, Compiled, DefaultGenerator, and TypeEngine. All Dialects implement the following attributes: - positional - True if the paramstyle for this Dialect is positional + positional + True if the paramstyle for this Dialect is positional. - paramstyle - The paramstyle to be used (some DBAPIs support multiple paramstyles) + paramstyle + the paramstyle to be used (some DB-APIs support multiple + paramstyles). - convert_unicode - True if unicode conversion should be applied to all str types - - encoding - type of encoding to use for unicode, usually defaults to 'utf-8' - """ + convert_unicode + True if Unicode conversion should be applied to all ``str`` + types. - def create_connect_args(self, url): - """Build DBAPI compatible connection arguments. + encoding + type of encoding to use for unicode, usually defaults to + 'utf-8'. 
- Given a [sqlalchemy.engine.url#URL] object, returns a - tuple consisting of a `*args`/`**kwargs` suitable to send directly - to the dbapi's connect function. - """ + schemagenerator + a [sqlalchemy.schema#SchemaVisitor] class which generates + schemas. - raise NotImplementedError() + schemadropper + a [sqlalchemy.schema#SchemaVisitor] class which drops schemas. - def dbapi_type_map(self): - """return a mapping of DBAPI type objects present in this Dialect's DBAPI - mapped to TypeEngine implementations used by the dialect. - - This is used to apply types to result sets based on the DBAPI types - present in cursor.description; it only takes effect for result sets against - textual statements where no explicit typemap was present. Constructed SQL statements - always have type information explicitly embedded. - """ + defaultrunner + a [sqlalchemy.schema#SchemaVisitor] class which executes + defaults. - raise NotImplementedError() + statement_compiler + a [sqlalchemy.engine.base#Compiled] class used to compile SQL + statements - def type_descriptor(self, typeobj): - """Transform the given [sqlalchemy.types#TypeEngine] instance from generic to database-specific. + preparer + a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to + quote identifiers. - Subclasses will usually use the [sqlalchemy.types#adapt_type()] method in the types module - to make this job easy. - """ + supports_alter + ``True`` if the database supports ``ALTER TABLE``. - raise NotImplementedError() + max_identifier_length + The maximum length of identifier names. - def oid_column_name(self, column): - """Return the oid column name for this dialect, or ``None`` if the dialect can't/won't support OID/ROWID. + supports_unicode_statements + Indicate whether the DB-API can receive SQL statements as Python unicode strings - The [sqlalchemy.schema#Column] instance which represents OID for the query being - compiled is passed, so that the dialect can inspect the column - and its parent selectable to determine if OID/ROWID is not - selected for a particular selectable (i.e. oracle doesnt - support ROWID for UNION, GROUP BY, DISTINCT, etc.) - """ + supports_sane_rowcount + Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. - raise NotImplementedError() + supports_sane_multi_rowcount + Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements + when executed via executemany. - def supports_alter(self): - """return ``True`` if the database supports ``ALTER TABLE``.""" - raise NotImplementedError() + preexecute_pk_sequences + Indicate if the dialect should pre-execute sequences on primary key + columns during an INSERT, if it's desired that the new row's primary key + be available after execution. - def max_identifier_length(self): - """Return the maximum length of identifier names. - - Return ``None`` if no limit.""" - - return None + supports_pk_autoincrement + Indicates if the dialect should allow the database to passively assign + a primary key column value. - def supports_unicode_statements(self): - """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" - - raise NotImplementedError() - - def supports_sane_rowcount(self): - """Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. + dbapi_type_map + A mapping of DB-API type objects present in this Dialect's + DB-API implmentation mapped to TypeEngine implementations used + by the dialect. 
- This was needed for MySQL which had non-standard behavior of rowcount, - but this issue has since been resolved. - """ + This is used to apply types to result sets based on the DB-API + types present in cursor.description; it only takes effect for + result sets against textual statements where no explicit + typemap was present. - raise NotImplementedError() + """ - def schemagenerator(self, connection, **kwargs): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can generate schemas. + def create_connect_args(self, url): + """Build DB-API compatible connection arguments. - connection - a [sqlalchemy.engine#Connection] to use for statement execution - - `schemagenerator()` is called via the `create()` method on Table, - Index, and others. + Given a [sqlalchemy.engine.url#URL] object, returns a tuple + consisting of a `*args`/`**kwargs` suitable to send directly + to the dbapi's connect function. """ raise NotImplementedError() - def schemadropper(self, connection, **kwargs): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can drop schemas. - connection - a [sqlalchemy.engine#Connection] to use for statement execution + def type_descriptor(self, typeobj): + """Transform a generic type to a database-specific type. + + Transforms the given [sqlalchemy.types#TypeEngine] instance + from generic to database-specific. - `schemadropper()` is called via the `drop()` method on Table, - Index, and others. + Subclasses will usually use the + [sqlalchemy.types#adapt_type()] method in the types module to + make this job easy. """ raise NotImplementedError() - def defaultrunner(self, execution_context): - """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults. - - execution_context - a [sqlalchemy.engine#ExecutionContext] to use for statement execution - + def oid_column_name(self, column): + """Return the oid column name for this Dialect + + May return ``None`` if the dialect can't o won't support + OID/ROWID features. + + The [sqlalchemy.schema#Column] instance which represents OID + for the query being compiled is passed, so that the dialect + can inspect the column and its parent selectable to determine + if OID/ROWID is not selected for a particular selectable + (i.e. Oracle doesnt support ROWID for UNION, GROUP BY, + DISTINCT, etc.) """ raise NotImplementedError() - def compiler(self, statement, parameters): - """Return a [sqlalchemy.sql#Compiled] object for the given statement/parameters. - The returned object is usually a subclass of [sqlalchemy.ansisql#ANSICompiler]. - """ + def server_version_info(self, connection): + """Return a tuple of the database's version number.""" raise NotImplementedError() def reflecttable(self, connection, table, include_columns=None): """Load table description from the database. - Given a [sqlalchemy.engine#Connection] and a [sqlalchemy.schema#Table] object, reflect its - columns and properties from the database. If include_columns (a list or set) is specified, limit the autoload - to the given column names. + Given a [sqlalchemy.engine#Connection] and a + [sqlalchemy.schema#Table] object, reflect its columns and + properties from the database. If include_columns (a list or + set) is specified, limit the autoload to the given column + names. """ raise NotImplementedError() @@ -161,9 +162,10 @@ class Dialect(object): def has_table(self, connection, table_name, schema=None): """Check the existence of a particular table in the database. 
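[editor's note: the has_table() contract being documented here can be exercised directly at the dialect level; a minimal sketch against an illustrative in-memory SQLite engine.]

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer

    engine = create_engine('sqlite:///:memory:')
    conn = engine.connect()
    print engine.dialect.has_table(conn, 'users')   # False

    users = Table('users', MetaData(), Column('id', Integer, primary_key=True))
    users.create(bind=conn)
    print engine.dialect.has_table(conn, 'users')   # True
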
- Given a [sqlalchemy.engine#Connection] object and a string `table_name`, return True - if the given table (possibly within the specified `schema`) - exists in the database, False otherwise. + Given a [sqlalchemy.engine#Connection] object and a string + `table_name`, return True if the given table (possibly within + the specified `schema`) exists in the database, False + otherwise. """ raise NotImplementedError() @@ -171,9 +173,9 @@ class Dialect(object): def has_sequence(self, connection, sequence_name): """Check the existence of a particular sequence in the database. - Given a [sqlalchemy.engine#Connection] object and a string `sequence_name`, return - True if the given sequence exists in the database, False - otherwise. + Given a [sqlalchemy.engine#Connection] object and a string + `sequence_name`, return True if the given sequence exists in + the database, False otherwise. """ raise NotImplementedError() @@ -185,21 +187,31 @@ class Dialect(object): def create_execution_context(self, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None): """Return a new [sqlalchemy.engine#ExecutionContext] object.""" - + raise NotImplementedError() def do_begin(self, connection): - """Provide an implementation of *connection.begin()*, given a DBAPI connection.""" + """Provide an implementation of *connection.begin()*, given a DB-API connection.""" raise NotImplementedError() def do_rollback(self, connection): - """Provide an implementation of *connection.rollback()*, given a DBAPI connection.""" + """Provide an implementation of *connection.rollback()*, given a DB-API connection.""" + + raise NotImplementedError() + + def create_xid(self): + """Create a two-phase transaction ID. + + This id will be passed to do_begin_twophase(), + do_rollback_twophase(), do_commit_twophase(). Its format is + unspecified. + """ raise NotImplementedError() def do_commit(self, connection): - """Provide an implementation of *connection.commit()*, given a DBAPI connection.""" + """Provide an implementation of *connection.commit()*, given a DB-API connection.""" raise NotImplementedError() @@ -243,29 +255,19 @@ class Dialect(object): raise NotImplementedError() - def do_executemany(self, cursor, statement, parameters): + def do_executemany(self, cursor, statement, parameters, context=None): """Provide an implementation of *cursor.executemany(statement, parameters)*.""" raise NotImplementedError() - def do_execute(self, cursor, statement, parameters): + def do_execute(self, cursor, statement, parameters, context=None): """Provide an implementation of *cursor.execute(statement, parameters)*.""" raise NotImplementedError() - - def compile(self, clauseelement, parameters=None): - """Compile the given [sqlalchemy.sql#ClauseElement] using this Dialect. - - Returns [sqlalchemy.sql#Compiled]. A convenience method which - flips around the compile() call on ``ClauseElement``. - """ - - return clauseelement.compile(dialect=self, parameters=parameters) - def is_disconnect(self, e): - """Return True if the given DBAPI error indicates an invalid connection""" - + """Return True if the given DB-API error indicates an invalid connection""" + raise NotImplementedError() @@ -273,91 +275,121 @@ class ExecutionContext(object): """A messenger object for a Dialect that corresponds to a single execution. ExecutionContext should have these datamembers: - - connection - Connection object which initiated the call to the - dialect to create this ExecutionContext. - dialect - dialect which created this ExecutionContext. 
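[editor's note: the ExecutionContext data members enumerated here ultimately surface on the ResultProxy returned from an execution; a small sketch of last_inserted_ids() with an illustrative table and SQLite URL.]

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String

    engine = create_engine('sqlite:///:memory:')
    meta = MetaData(bind=engine)
    users = Table('users', meta,
                  Column('id', Integer, primary_key=True),
                  Column('name', String(50)))
    meta.create_all()

    result = engine.execute(users.insert(), name='ed')
    print result.last_inserted_ids()      # e.g. [1]
    print result.last_inserted_params()   # parameters used for the row
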
- - cursor - DBAPI cursor procured from the connection - - compiled - if passed to constructor, sql.Compiled object being executed - - statement - string version of the statement to be executed. Is either - passed to the constructor, or must be created from the - sql.Compiled object by the time pre_exec() has completed. - - parameters - bind parameters passed to the execute() method. for - compiled statements, this is a dictionary or list - of dictionaries. for textual statements, it should - be in a format suitable for the dialect's paramstyle - (i.e. dict or list of dicts for non positional, - list or list of lists/tuples for positional). - - + connection + Connection object which can be freely used by default value + generators to execute SQL. This Connection should reference the + same underlying connection/transactional resources of + root_connection. + + root_connection + Connection object which is the source of this ExecutionContext. This + Connection may have close_with_result=True set, in which case it can + only be used once. + + dialect + dialect which created this ExecutionContext. + + cursor + DB-API cursor procured from the connection, + + compiled + if passed to constructor, sqlalchemy.engine.base.Compiled object + being executed, + + statement + string version of the statement to be executed. Is either + passed to the constructor, or must be created from the + sql.Compiled object by the time pre_exec() has completed. + + parameters + bind parameters passed to the execute() method. For compiled + statements, this is a dictionary or list of dictionaries. For + textual statements, it should be in a format suitable for the + dialect's paramstyle (i.e. dict or list of dicts for non + positional, list or list of lists/tuples for positional). + + isinsert + True if the statement is an INSERT. + + isupdate + True if the statement is an UPDATE. + + should_autocommit + True if the statement is a "committable" statement + + returns_rows + True if the statement should return result rows + + postfetch_cols + a list of Column objects for which a server-side default + or inline SQL expression value was fired off. applies to inserts and updates. + The Dialect should provide an ExecutionContext via the create_execution_context() method. The `pre_exec` and `post_exec` methods will be called for compiled statements. - """ def create_cursor(self): """Return a new cursor generated from this ExecutionContext's connection. - - Some dialects may wish to change the behavior of connection.cursor(), - such as postgres which may return a PG "server side" cursor. + + Some dialects may wish to change the behavior of + connection.cursor(), such as postgres which may return a PG + "server side" cursor. """ raise NotImplementedError() def pre_execution(self): """Called before an execution of a compiled statement. - - If a compiled statement was passed to this - ExecutionContext, the `statement` and `parameters` datamembers - must be initialized after this statement is complete. + + If a compiled statement was passed to this ExecutionContext, + the `statement` and `parameters` datamembers must be + initialized after this statement is complete. """ raise NotImplementedError() def post_execution(self): """Called after the execution of a compiled statement. - + If a compiled statement was passed to this ExecutionContext, - the `last_insert_ids`, `last_inserted_params`, etc. - datamembers should be available after this method - completes. + the `last_insert_ids`, `last_inserted_params`, etc. 
+ datamembers should be available after this method completes. """ raise NotImplementedError() - + def result(self): - """return a result object corresponding to this ExecutionContext. - - Returns a ResultProxy.""" - + """Return a result object corresponding to this ExecutionContext. + + Returns a ResultProxy. + """ + raise NotImplementedError() - + def get_rowcount(self): """Return the count of rows updated/deleted for an UPDATE/DELETE statement.""" raise NotImplementedError() + def should_autocommit_compiled(self, compiled): + """return True if the given Compiled object refers to a "committable" statement.""" + + raise NotImplementedError() + + def should_autocommit_text(self, statement): + """Parse the given textual statement and return True if it refers to a "committable" statement""" + + raise NotImplementedError() + def last_inserted_ids(self): """Return the list of the primary key values for the last insert statement executed. This does not apply to straight textual clauses; only to - ``sql.Insert`` objects compiled against a ``schema.Table`` object. - The order of - items in the list is the same as that of the Table's - 'primary_key' attribute. - + ``sql.Insert`` objects compiled against a ``schema.Table`` + object. The order of items in the list is the same as that of + the Table's 'primary_key' attribute. """ raise NotImplementedError() @@ -379,15 +411,13 @@ class ExecutionContext(object): raise NotImplementedError() def lastrow_has_defaults(self): - """Return True if the last row INSERTED via a compiled insert statement contained PassiveDefaults. - - The presence of PassiveDefaults indicates that the database - inserted data beyond that which we passed to the query - programmatically. + """Return True if the last INSERT or UPDATE row contained + inlined or database-side defaults. """ raise NotImplementedError() + class Compiled(object): """Represent a compiled SQL expression. @@ -401,55 +431,47 @@ class Compiled(object): defaults. """ - def __init__(self, dialect, statement, parameters, bind=None): + def __init__(self, dialect, statement, column_keys=None, bind=None): """Construct a new ``Compiled`` object. + dialect + ``Dialect`` to compile against. + statement ``ClauseElement`` to be compiled. - parameters - Optional dictionary indicating a set of bind parameters - specified with this ``Compiled`` object. These parameters - are the *default* values corresponding to the - ``ClauseElement``'s ``_BindParamClauses`` when the - ``Compiled`` is executed. In the case of an ``INSERT`` or - ``UPDATE`` statement, these parameters will also result in - the creation of new ``_BindParamClause`` objects for each - key and will also affect the generated column list in an - ``INSERT`` statement and the ``SET`` clauses of an - ``UPDATE`` statement. The keys of the parameter dictionary - can either be the string names of columns or - ``_ColumnClause`` objects. + column_keys + a list of column names to be compiled into an INSERT or UPDATE + statement. bind Optional Engine or Connection to compile this statement against. """ self.dialect = dialect self.statement = statement - self.parameters = parameters + self.column_keys = column_keys self.bind = bind self.can_execute = statement.supports_execution() - + def compile(self): """Produce the internal string representation of this element.""" - + raise NotImplementedError() - + def __str__(self): """Return the string text of the generated SQL statement.""" raise NotImplementedError() def get_params(self, **params): - """Deprecated. 
use construct_params(). (supports unicode names) - """ - + """Use construct_params(). (supports unicode names)""" return self.construct_params(params) + get_params = util.deprecated()(get_params) def construct_params(self, params): """Return the bind params for this compiled object. - params is a dict of string/object pairs whos + `params` is a dict of string/object pairs whos values will override bind values compiled in to the statement. """ @@ -460,7 +482,7 @@ class Compiled(object): e = self.bind if e is None: - raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.") + raise exceptions.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.") return e._execute_compiled(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -490,11 +512,11 @@ class Connectable(object): def execute(self, object, *multiparams, **params): raise NotImplementedError() - engine = util.NotImplProperty("The Engine which this Connectable is associated with.") - dialect = util.NotImplProperty("Dialect which this Connectable is associated with.") + def execute_clauseelement(self, elem, multiparams=None, params=None): + raise NotImplementedError() class Connection(Connectable): - """Represent a single DBAPI connection returned from the underlying connection pool. + """Provides high-level functionality for a wrapped DB-API connection. Provides execution support for string-based SQL statements as well as ClauseElement, Compiled and DefaultGenerator objects. Provides @@ -503,305 +525,478 @@ class Connection(Connectable): The Connection object is **not** threadsafe. """ - def __init__(self, engine, connection=None, close_with_result=False): - self.__engine = engine + def __init__(self, engine, connection=None, close_with_result=False, + _branch=False): + """Construct a new Connection. + + Connection objects are typically constructed by an + [sqlalchemy.engine#Engine], see the ``connect()`` and + ``contextual_connect()`` methods of Engine. + """ + + self.engine = engine self.__connection = connection or engine.raw_connection() self.__transaction = None self.__close_with_result = close_with_result self.__savepoint_seq = 0 + self.__branch = _branch + self.__invalid = False + + def _branch(self): + """Return a new Connection which references this Connection's + engine and connection; but does not have close_with_result enabled, + and also whose close() method does nothing. + + This is used to execute "sub" statements within a single execution, + usually an INSERT statement. + """ + return Connection(self.engine, self.__connection, _branch=True) + + def dialect(self): + "Dialect used by this Connection." + + return self.engine.dialect + dialect = property(dialect) + + def closed(self): + """return True if this connection is closed.""" + + return not self.__invalid and '_Connection__connection' not in self.__dict__ + closed = property(closed) + + def invalidated(self): + """return True if this connection was invalidated.""" + + return self.__invalid + invalidated = property(invalidated) + + def connection(self): + "The underlying DB-API connection managed by this Connection." 
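The Connection accessors introduced above (engine, closed, invalidated, connection) are easiest to see in a small usage sketch. The engine URL and the users table below are illustrative only and are reused by the sketches that follow.

    # illustrative setup: an in-memory SQLite engine and a throwaway 'users' table
    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String

    engine = create_engine('sqlite://')
    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)))
    metadata.create_all(engine)

    conn = engine.connect()           # Connection wrapping a pooled DB-API connection
    assert conn.engine is engine      # the public .engine attribute
    assert not conn.closed and not conn.invalidated
    raw = conn.connection             # the underlying DB-API connection
    conn.close()
    assert conn.closed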
- def _get_connection(self): try: return self.__connection except AttributeError: + if self.__invalid: + if self.__transaction is not None: + raise exceptions.InvalidRequestError("Can't reconnect until invalid transaction is rolled back") + self.__connection = self.engine.raw_connection() + self.__invalid = False + return self.__connection raise exceptions.InvalidRequestError("This Connection is closed") + connection = property(connection) - def _branch(self): - """return a new Connection which references this Connection's - engine and connection; but does not have close_with_result enabled.""" - - return Connection(self.__engine, self.__connection) - - engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.") - dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.") - connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.") - should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.") - properties = property(lambda s: s._get_connection().properties, - doc="A set of per-DBAPI connection properties.") + def should_close_with_result(self): + """Indicates if this Connection should be closed when a corresponding + ResultProxy is closed; this is essentially an auto-release mode. + """ + + return self.__close_with_result + should_close_with_result = property(should_close_with_result) + + def info(self): + """A collection of per-DB-API connection instance properties.""" + return self.connection.info + info = property(info) + + properties = property(info, doc="""An alias for the .info collection, will be removed in 0.5.""") def connect(self): - """connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly.""" + """Returns self. + + This ``Connectable`` interface method returns self, allowing + Connections to be used interchangably with Engines in most + situations that require a bind. + """ + return self def contextual_connect(self, **kwargs): - """contextual_connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly.""" + """Returns self. + + This ``Connectable`` interface method returns self, allowing + Connections to be used interchangably with Engines in most + situations that require a bind. + """ + return self - def invalidate(self): - """invalidate the underying DBAPI connection and immediately close this Connection. - - The underlying DBAPI connection is literally closed (if possible), and is discarded. - Its source connection pool will typically create a new connection to replace it, once - requested. + def invalidate(self, exception=None): + """Invalidate the underlying DBAPI connection associated with this Connection. + + The underlying DB-API connection is literally closed (if + possible), and is discarded. Its source connection pool will + typically lazily create a new connection to replace it. + + Upon the next usage, this Connection will attempt to reconnect + to the pool with a new connection. + + Transactions in progress remain in an "opened" state (even though + the actual transaction is gone); these must be explicitly + rolled back before a reconnect on this Connection can proceed. This + is to prevent applications from accidentally continuing their transactional + operations in a non-transactional state. 
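A short sketch of the invalidation behavior described above, reusing the illustrative engine from the previous sketch; the two points of interest are the reconnect on next use and the required rollback when a transaction was in progress.

    conn = engine.connect()
    conn.invalidate()                 # the DB-API connection is discarded
    assert conn.invalidated
    conn.execute("SELECT 1")          # next use procures a fresh connection from the pool

    trans = conn.begin()
    conn.invalidate()                 # invalidation with a transaction in progress...
    trans.rollback()                  # ...requires an explicit rollback before reuse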
+ """ - - self.__connection.invalidate() - self.__connection = None + + if self.__connection.is_valid: + self.__connection.invalidate(exception) + del self.__connection + self.__invalid = True def detach(self): - """detach the underlying DBAPI connection from its connection pool. - - This Connection instance will remain useable. When closed, the - DBAPI connection will be literally closed and not returned to its pool. - The pool will typically create a new connection to replace it, once requested. - - This method can be used to insulate the rest of an application from a modified - state on a connection (such as a transaction isolation level or similar). + """Detach the underlying DB-API connection from its connection pool. + + This Connection instance will remain useable. When closed, + the DB-API connection will be literally closed and not + returned to its pool. The pool will typically lazily create a + new connection to replace the detached connection. + + This method can be used to insulate the rest of an application + from a modified state on a connection (such as a transaction + isolation level or similar). Also see + [sqlalchemy.interfaces#PoolListener] for a mechanism to modify + connection state when connections leave and return to their + connection pool. """ - + self.__connection.detach() - - def begin(self, nested=False): + + def begin(self): + """Begin a transaction and return a Transaction handle. + + Repeated calls to ``begin`` on the same Connection will create + a lightweight, emulated nested transaction. Only the + outermost transaction may ``commit``. Calls to ``commit`` on + inner transactions are ignored. Any transaction in the + hierarchy may ``rollback``, however. + """ + if self.__transaction is None: self.__transaction = RootTransaction(self) - elif nested: - self.__transaction = NestedTransaction(self, self.__transaction) else: return Transaction(self, self.__transaction) return self.__transaction def begin_nested(self): - return self.begin(nested=True) - + """Begin a nested transaction and return a Transaction handle. + + Nested transactions require SAVEPOINT support in the + underlying database. Any transaction in the hierarchy may + ``commit`` and ``rollback``, however the outermost transaction + still controls the overall ``commit`` or ``rollback`` of the + transaction of a whole. + """ + + if self.__transaction is None: + self.__transaction = RootTransaction(self) + else: + self.__transaction = NestedTransaction(self, self.__transaction) + return self.__transaction + def begin_twophase(self, xid=None): + """Begin a two-phase or XA transaction and return a Transaction handle. + + xid + the two phase transaction id. If not supplied, a random id + will be generated. 
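The begin()/begin_nested() semantics described above, sketched against the illustrative engine and users table from the first example; the SAVEPOINT portion assumes a backend that supports savepoints.

    conn = engine.connect()
    outer = conn.begin()
    inner = conn.begin()              # emulated nesting: only the outermost commits
    inner.commit()                    # ignored
    outer.commit()                    # the actual COMMIT

    trans = conn.begin()
    sp = conn.begin_nested()          # emits SAVEPOINT on supporting backends
    conn.execute(users.insert(), name='ed')
    sp.rollback()                     # ROLLBACK TO SAVEPOINT; the insert is discarded
    trans.commit()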
+ """ + if self.__transaction is not None: - raise exceptions.InvalidRequestError("Cannot start a two phase transaction when a transaction is already started.") + raise exceptions.InvalidRequestError( + "Cannot start a two phase transaction when a transaction " + "is already in progress.") if xid is None: - xid = "_sa_%032x" % random.randint(0,2**128) + xid = self.engine.dialect.create_xid(); self.__transaction = TwoPhaseTransaction(self, xid) return self.__transaction - + def recover_twophase(self): - return self.__engine.dialect.do_recover_twophase(self) - + return self.engine.dialect.do_recover_twophase(self) + def rollback_prepared(self, xid, recover=False): - self.__engine.dialect.do_rollback_twophase(self, xid, recover=recover) - + self.engine.dialect.do_rollback_twophase(self, xid, recover=recover) + def commit_prepared(self, xid, recover=False): - self.__engine.dialect.do_commit_twophase(self, xid, recover=recover) + self.engine.dialect.do_commit_twophase(self, xid, recover=recover) def in_transaction(self): + """Return True if a transaction is in progress.""" + return self.__transaction is not None def _begin_impl(self): - if self.__connection.is_valid: - self.__engine.logger.info("BEGIN") - try: - self.__engine.dialect.do_begin(self.connection) - except Exception, e: - raise exceptions.SQLError(None, None, e) + if self.engine._should_log_info: + self.engine.logger.info("BEGIN") + try: + self.engine.dialect.do_begin(self.connection) + except Exception, e: + self._handle_dbapi_exception(e, None, None, None) + raise def _rollback_impl(self): - if self.__connection.is_valid: - self.__engine.logger.info("ROLLBACK") + if not self.closed and not self.invalidated and self.__connection.is_valid: + if self.engine._should_log_info: + self.engine.logger.info("ROLLBACK") try: - self.__engine.dialect.do_rollback(self.connection) + self.engine.dialect.do_rollback(self.connection) + self.__transaction = None except Exception, e: - raise exceptions.SQLError(None, None, e) - self.__connection.close_open_cursors() - self.__transaction = None + self._handle_dbapi_exception(e, None, None, None) + raise + else: + self.__transaction = None def _commit_impl(self): - if self.__connection.is_valid: - self.__engine.logger.info("COMMIT") - try: - self.__engine.dialect.do_commit(self.connection) - except Exception, e: - raise exceptions.SQLError(None, None, e) - self.__transaction = None + if self.engine._should_log_info: + self.engine.logger.info("COMMIT") + try: + self.engine.dialect.do_commit(self.connection) + self.__transaction = None + except Exception, e: + self._handle_dbapi_exception(e, None, None, None) + raise def _savepoint_impl(self, name=None): if name is None: self.__savepoint_seq += 1 - name = '__sa_savepoint_%s' % self.__savepoint_seq + name = 'sa_savepoint_%s' % self.__savepoint_seq if self.__connection.is_valid: - self.__engine.dialect.do_savepoint(self, name) + self.engine.dialect.do_savepoint(self, name) return name - + def _rollback_to_savepoint_impl(self, name, context): if self.__connection.is_valid: - self.__engine.dialect.do_rollback_to_savepoint(self, name) + self.engine.dialect.do_rollback_to_savepoint(self, name) self.__transaction = context - + def _release_savepoint_impl(self, name, context): if self.__connection.is_valid: - self.__engine.dialect.do_release_savepoint(self, name) + self.engine.dialect.do_release_savepoint(self, name) self.__transaction = context - + def _begin_twophase_impl(self, xid): if self.__connection.is_valid: - 
self.__engine.dialect.do_begin_twophase(self, xid) - + self.engine.dialect.do_begin_twophase(self, xid) + def _prepare_twophase_impl(self, xid): if self.__connection.is_valid: assert isinstance(self.__transaction, TwoPhaseTransaction) - self.__engine.dialect.do_prepare_twophase(self, xid) - + self.engine.dialect.do_prepare_twophase(self, xid) + def _rollback_twophase_impl(self, xid, is_prepared): if self.__connection.is_valid: assert isinstance(self.__transaction, TwoPhaseTransaction) - self.__engine.dialect.do_rollback_twophase(self, xid, is_prepared) + self.engine.dialect.do_rollback_twophase(self, xid, is_prepared) self.__transaction = None def _commit_twophase_impl(self, xid, is_prepared): if self.__connection.is_valid: assert isinstance(self.__transaction, TwoPhaseTransaction) - self.__engine.dialect.do_commit_twophase(self, xid, is_prepared) + self.engine.dialect.do_commit_twophase(self, xid, is_prepared) self.__transaction = None - def _autocommit(self, statement): - """When no Transaction is present, this is called after executions to provide "autocommit" behavior.""" - # TODO: have the dialect determine if autocommit can be set on the connection directly without this - # extra step - if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip(), re.I): + def _autocommit(self, context): + """Possibly issue a commit. + + When no Transaction is present, this is called after statement + execution to provide "autocommit" behavior. Dialects may + inspect the statement to determine if a commit is actually + required. + """ + + # TODO: have the dialect determine if autocommit can be set on + # the connection directly without this extra step + if not self.in_transaction() and context.should_autocommit: self._commit_impl() def _autorollback(self): if not self.in_transaction(): self._rollback_impl() - + def close(self): + """Close this Connection.""" + try: - c = self.__connection + conn = self.__connection except AttributeError: return - self.__connection.close() - self.__connection = None + if not self.__branch: + conn.close() + self.__invalid = False del self.__connection def scalar(self, object, *multiparams, **params): + """Executes and returns the first column of the first row. + + The underlying result/cursor is closed after execution. 
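A sketch of the autocommit rule described above together with scalar(), again using the illustrative users table: with no Transaction in progress, a committable statement is committed as soon as it executes.

    conn = engine.connect()
    conn.execute(users.insert(), name='wendy')           # autocommitted
    count = conn.scalar("SELECT count(*) FROM users")    # first column of first row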
+ """ + return self.execute(object, *multiparams, **params).scalar() - def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) def execute(self, object, *multiparams, **params): + """Executes and returns a ResultProxy.""" + for c in type(object).__mro__: if c in Connection.executors: return Connection.executors[c](self, object, multiparams, params) else: - raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object))) + raise exceptions.InvalidRequestError("Unexecutable object type: " + str(type(object))) def _execute_default(self, default, multiparams=None, params=None): - return self.__engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) + return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) def _execute_text(self, statement, multiparams, params): parameters = self.__distill_params(multiparams, params) context = self.__create_execution_context(statement=statement, parameters=parameters) self.__execute_raw(context) + self._autocommit(context) return context.result() def __distill_params(self, multiparams, params): + """given arguments from the calling form *multiparams, **params, return a list + of bind parameter structures, usually a list of dictionaries. + + in the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists.""" + if multiparams is None or len(multiparams) == 0: - parameters = params or None - elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)): - parameters = multiparams[0] + if params: + return [params] + else: + return [{}] + elif len(multiparams) == 1: + if isinstance(multiparams[0], (list, tuple)): + if isinstance(multiparams[0][0], (list, tuple, dict)): + return multiparams[0] + else: + return [multiparams[0]] + elif isinstance(multiparams[0], dict): + return [multiparams[0]] + else: + return [[multiparams[0]]] else: - parameters = list(multiparams) - return parameters + if isinstance(multiparams[0], (list, tuple, dict)): + return multiparams + else: + return [multiparams] def _execute_function(self, func, multiparams, params): - return self._execute_clauseelement(func.select(), multiparams, params) - - def _execute_clauseelement(self, elem, multiparams=None, params=None): - executemany = multiparams is not None and len(multiparams) > 0 - if executemany: - param = multiparams[0] + return self.execute_clauseelement(func.select(), multiparams, params) + + def execute_clauseelement(self, elem, multiparams=None, params=None): + params = self.__distill_params(multiparams, params) + if params: + keys = params[0].keys() else: - param = params - return self._execute_compiled(elem.compile(dialect=self.dialect, parameters=param), multiparams, params) + keys = None + return self._execute_compiled(elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), distilled_params=params) - def _execute_compiled(self, compiled, multiparams=None, params=None): + def _execute_compiled(self, compiled, multiparams=None, params=None, distilled_params=None): """Execute a sql.Compiled object.""" if not compiled.can_execute: - raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled))) + raise exceptions.ArgumentError("Not an executable clause: %s" % (str(compiled))) + + if 
distilled_params is None: + distilled_params = self.__distill_params(multiparams, params) + context = self.__create_execution_context(compiled=compiled, parameters=distilled_params) - params = self.__distill_params(multiparams, params) - context = self.__create_execution_context(compiled=compiled, parameters=params) - context.pre_execution() self.__execute_raw(context) context.post_execution() + self._autocommit(context) return context.result() - - def __create_execution_context(self, **kwargs): - return self.__engine.dialect.create_execution_context(connection=self, **kwargs) - + def __execute_raw(self, context): - if logging.is_info_enabled(self.__engine.logger): - self.__engine.logger.info(context.statement) - self.__engine.logger.info(repr(context.parameters)) - if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)): - self.__executemany(context) + if context.executemany: + self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context) + else: + self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context) + + def _execute_ddl(self, ddl, params, multiparams): + if params: + schema_item, params = params[0], params[1:] else: - self.__execute(context) - self._autocommit(context.statement) + schema_item = None + return ddl(None, schema_item, self, *params, **multiparams) - def __execute(self, context): - if context.parameters is None: - if context.dialect.positional: - context.parameters = () + def _handle_dbapi_exception(self, e, statement, parameters, cursor): + if getattr(self, '_reentrant_error', False): + raise exceptions.DBAPIError.instance(None, None, e) + self._reentrant_error = True + try: + if not isinstance(e, self.dialect.dbapi.Error): + return + is_disconnect = self.dialect.is_disconnect(e) + if is_disconnect: + self.invalidate(e) + self.engine.dispose() else: - context.parameters = {} + if cursor: + cursor.close() + self._autorollback() + if self.__close_with_result: + self.close() + raise exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) + finally: + del self._reentrant_error + + def __create_execution_context(self, **kwargs): try: - context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context) + return self.engine.dialect.create_execution_context(connection=self, **kwargs) except Exception, e: - if self.dialect.is_disconnect(e): - self.__connection.invalidate(e=e) - self.engine.dispose() - self._autorollback() - if self.__close_with_result: - self.close() - raise exceptions.SQLError(context.statement, context.parameters, e) + self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None) + raise - def __executemany(self, context): + def _cursor_execute(self, cursor, statement, parameters, context=None): + if self.engine._should_log_info: + self.engine.logger.info(statement) + self.engine.logger.info(repr(parameters)) try: - context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context) + self.dialect.do_execute(cursor, statement, parameters, context=context) except Exception, e: - if self.dialect.is_disconnect(e): - self.__connection.invalidate(e=e) - self.engine.dispose() - self._autorollback() - if self.__close_with_result: - self.close() - raise exceptions.SQLError(context.statement, context.parameters, e) + 
self._handle_dbapi_exception(e, statement, parameters, cursor) + raise + + def _cursor_executemany(self, cursor, statement, parameters, context=None): + if self.engine._should_log_info: + self.engine.logger.info(statement) + self.engine.logger.info(repr(parameters)) + try: + self.dialect.do_executemany(cursor, statement, parameters, context=context) + except Exception, e: + self._handle_dbapi_exception(e, statement, parameters, cursor) + raise # poor man's multimethod/generic function thingy executors = { - sql._Function : _execute_function, - sql.ClauseElement : _execute_clauseelement, - sql.ClauseVisitor : _execute_compiled, - schema.SchemaItem:_execute_default, - str.__mro__[-2] : _execute_text + expression._Function: _execute_function, + expression.ClauseElement: execute_clauseelement, + Compiled: _execute_compiled, + schema.SchemaItem: _execute_default, + schema.DDL: _execute_ddl, + basestring: _execute_text } def create(self, entity, **kwargs): """Create a Table or Index given an appropriate Schema object.""" - return self.__engine.create(entity, connection=self, **kwargs) + return self.engine.create(entity, connection=self, **kwargs) def drop(self, entity, **kwargs): """Drop a Table or Index given an appropriate Schema object.""" - return self.__engine.drop(entity, connection=self, **kwargs) + return self.engine.drop(entity, connection=self, **kwargs) def reflecttable(self, table, include_columns=None): """Reflect the columns in the given string table name from the database.""" - return self.__engine.reflecttable(table, self, include_columns) + return self.engine.reflecttable(table, self, include_columns) def default_schema_name(self): - return self.__engine.dialect.get_default_schema_name(self) + return self.engine.dialect.get_default_schema_name(self) def run_callable(self, callable_): return callable_(self) @@ -817,24 +1012,45 @@ class Transaction(object): self._parent = parent or self self._is_active = True - connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction") - is_active = property(lambda s:s._is_active) + def connection(self): + "The Connection object referenced by this Transaction" + return self._connection + connection = property(connection) + + def is_active(self): + return self._is_active + is_active = property(is_active) + + def close(self): + """Close this transaction. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. 
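The Transaction handle returned by begin() is typically used with the commit-or-rollback idiom below; close(), as described above, cancels an inner handle without affecting the enclosing transaction. The statements are illustrative.

    conn = engine.connect()
    trans = conn.begin()
    try:
        conn.execute(users.update(), name='jack')
        trans.commit()
    except:
        trans.rollback()
        raise

    outer = conn.begin()
    inner = conn.begin()
    inner.close()                     # a no-op for the inner handle
    outer.commit()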
+ """ + if not self._parent._is_active: + return + if self._parent is self: + self.rollback() def rollback(self): if not self._parent._is_active: return self._is_active = False self._do_rollback() - + def _do_rollback(self): self._parent.rollback() def commit(self): if not self._parent._is_active: raise exceptions.InvalidRequestError("This transaction is inactive") - self._is_active = False self._do_commit() - + self._is_active = False + def _do_commit(self): pass @@ -851,7 +1067,7 @@ class RootTransaction(Transaction): def __init__(self, connection): super(RootTransaction, self).__init__(connection, None) self._connection._begin_impl() - + def _do_rollback(self): self._connection._rollback_impl() @@ -862,7 +1078,7 @@ class NestedTransaction(Transaction): def __init__(self, connection, parent): super(NestedTransaction, self).__init__(connection, parent) self._savepoint = self._connection._savepoint_impl() - + def _do_rollback(self): self._connection._rollback_to_savepoint_impl(self._savepoint, self._parent) @@ -875,16 +1091,16 @@ class TwoPhaseTransaction(Transaction): self._is_prepared = False self.xid = xid self._connection._begin_twophase_impl(self.xid) - + def prepare(self): if not self._parent._is_active: raise exceptions.InvalidRequestError("This transaction is inactive") self._connection._prepare_twophase_impl(self.xid) self._is_prepared = True - + def _do_rollback(self): self._connection._rollback_twophase_impl(self.xid, self._is_prepared) - + def commit(self): self._connection._commit_twophase_impl(self.xid, self._is_prepared) @@ -897,18 +1113,22 @@ class Engine(Connectable): def __init__(self, pool, dialect, url, echo=None): self.pool = pool self.url = url - self._dialect=dialect + self.dialect=dialect self.echo = echo - self.logger = logging.instance_logger(self) + self.engine = self + self.logger = logging.instance_logger(self, echoflag=echo) + + def name(self): + "String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``." 
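The TwoPhaseTransaction shown above follows the usual prepare-then-commit protocol. This sketch assumes a hypothetical XA-capable backend (it will not run against the illustrative SQLite engine); the xid is generated by the dialect's create_xid() when not supplied.

    conn = engine.connect()
    xa = conn.begin_twophase()
    conn.execute(users.update(), name='jack')
    xa.prepare()                      # phase one
    xa.commit()                       # phase two

    # previously prepared transactions can be recovered and finished:
    for xid in conn.recover_twophase():
        conn.commit_prepared(xid, recover=True)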
+ + return sys.modules[self.dialect.__module__].descriptor()['name'] + name = property(name) - name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'], doc="String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``.") - engine = property(lambda s:s) - dialect = property(lambda s:s._dialect, doc="the [sqlalchemy.engine#Dialect] in use by this engine.") echo = logging.echo_property() - + def __repr__(self): return 'Engine(%s)' % str(self.url) - + def dispose(self): self.pool.dispose() self.pool = self.pool.recreate() @@ -930,15 +1150,14 @@ class Engine(Connectable): finally: connection.close() - def _func(self): - return sql._FunctionGenerator(bind=self) - - func = property(_func) + def func(self): + return expression._FunctionGenerator(bind=self) + func = property(func) def text(self, text, *args, **kwargs): """Return a sql.text() object for performing literal queries.""" - return sql.text(text, bind=self, *args, **kwargs) + return expression.text(text, bind=self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: @@ -946,7 +1165,7 @@ class Engine(Connectable): else: conn = connection try: - visitorcallable(conn, **kwargs).traverse(element) + visitorcallable(self.dialect, conn, **kwargs).traverse(element) finally: if connection is None: conn.close() @@ -995,12 +1214,16 @@ class Engine(Connectable): def scalar(self, statement, *multiparams, **params): return self.execute(statement, *multiparams, **params).scalar() + def execute_clauseelement(self, elem, multiparams=None, params=None): + connection = self.contextual_connect(close_with_result=True) + return connection.execute_clauseelement(elem, multiparams, params) + def _execute_compiled(self, compiled, multiparams, params): connection = self.contextual_connect(close_with_result=True) return connection._execute_compiled(compiled, multiparams, params) - def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -1013,7 +1236,33 @@ class Engine(Connectable): This Connection is meant to be used by the various "auto-connecting" operations. """ - return Connection(self, close_with_result=close_with_result, **kwargs) + return Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs) + + def table_names(self, schema=None, connection=None): + """Return a list of all table names available in the database. + + schema: + Optional, retrieve names from a non-default schema. + + connection: + Optional, use a specified connection. Default is the + ``contextual_connect`` for this ``Engine``. 
+ """ + + if connection is None: + conn = self.contextual_connect() + else: + conn = connection + if not schema: + try: + schema = self.dialect.get_default_schema_name(conn) + except NotImplementedError: + pass + try: + return self.dialect.table_names(conn, schema) + finally: + if connection is None: + conn.close() def reflecttable(self, table, connection=None, include_columns=None): """Given a Table object, reflects its columns and properties from the database.""" @@ -1032,17 +1281,93 @@ class Engine(Connectable): return self.run_callable(lambda c: self.dialect.has_table(c, table_name, schema=schema)) def raw_connection(self): - """Return a DBAPI connection.""" + """Return a DB-API connection.""" - return self.pool.connect() + return self.pool.unique_connection() - def log(self, msg): - """Log a message using this SQLEngine's logger stream.""" - self.logger.info(msg) +class RowProxy(object): + """Proxy a single cursor row for a parent ResultProxy. + + Mostly follows "ordered dictionary" behavior, mapping result + values to the string-based column name, the integer position of + the result in the row, as well as Column instances which can be + mapped to the original Columns that produced this result set (for + results that correspond to constructed SQL expressions). + """ + + def __init__(self, parent, row): + """RowProxy objects are constructed by ResultProxy objects.""" + + self.__parent = parent + self.__row = row + if self.__parent._ResultProxy__echo: + self.__parent.context.engine.logger.debug("Row " + repr(row)) + + def close(self): + """Close the parent ResultProxy.""" + + self.__parent.close() + + def __contains__(self, key): + return self.__parent._has_key(self.__row, key) + + def __len__(self): + return len(self.__row) + + def __iter__(self): + for i in xrange(len(self.__row)): + yield self.__parent._get_col(self.__row, i) + + def __eq__(self, other): + return ((other is self) or + (other == tuple([self.__parent._get_col(self.__row, key) + for key in xrange(len(self.__row))]))) + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return repr(tuple(self)) + + def has_key(self, key): + """Return True if this RowProxy contains the given key.""" + + return self.__parent._has_key(self.__row, key) + + def __getitem__(self, key): + return self.__parent._get_col(self.__row, key) + + def __getattr__(self, name): + try: + return self.__parent._get_col(self.__row, name) + except KeyError, e: + raise AttributeError(e.args[0]) + + def items(self): + """Return a list of tuples, each tuple containing a key/value pair.""" + + return [(key, getattr(self, key)) for key in self.keys()] + + def keys(self): + """Return the list of keys as strings represented by this RowProxy.""" + + return self.__parent.keys + + def values(self): + """Return the values represented by this RowProxy as a list.""" + + return list(self) + + +class BufferedColumnRow(RowProxy): + def __init__(self, parent, row): + row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))] + super(BufferedColumnRow, self).__init__(parent, row) + class ResultProxy(object): - """Wraps a DBAPI cursor object to provide easier access to row columns. + """Wraps a DB-API cursor object to provide easier access to row columns. Individual columns may be accessed by their integer position, case-insensitive column name, or by ``schema.Column`` @@ -1057,19 +1382,13 @@ class ResultProxy(object): col3 = row[mytable.c.mycol] # access via Column object. 
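A quick sketch of the Engine-level conveniences shown above, using the illustrative engine and users table. Implicit execution checks a Connection out of the pool and releases it once the result is fully consumed or closed.

    engine.table_names()                                 # e.g. ['users']
    engine.has_table('users')                            # True
    result = engine.execute("SELECT name FROM users")    # implicit execution
    names = result.fetchall()                            # exhausting the rows releases the connection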
ResultProxy also contains a map of TypeEngine objects and will - invoke the appropriate ``convert_result_value()`` method before + invoke the appropriate ``result_processor()`` method before returning columns, as well as the ExecutionContext corresponding to the statement execution. It provides several methods for which to obtain information from the underlying ExecutionContext. """ - class AmbiguousColumn(object): - def __init__(self, key): - self.key = key - def dialect_impl(self, dialect): - return self - def convert_result_value(self, arg, engine): - raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key)) + _process_row = RowProxy def __init__(self, context): """ResultProxy objects are constructed via the execute() method on SQLEngine.""" @@ -1077,51 +1396,72 @@ class ResultProxy(object): self.dialect = context.dialect self.closed = False self.cursor = context.cursor - self.__echo = logging.is_debug_enabled(context.engine.logger) - if context.is_select(): + self.connection = context.root_connection + self.__echo = context.engine._should_log_info + if context.returns_rows: self._init_metadata() self._rowcount = None else: self._rowcount = context.get_rowcount() self.close() - - connection = property(lambda self:self.context.connection) - def _get_rowcount(self): + + def rowcount(self): if self._rowcount is not None: return self._rowcount else: return self.context.get_rowcount() - rowcount = property(_get_rowcount) - lastrowid = property(lambda s:s.cursor.lastrowid) - out_parameters = property(lambda s:s.context.out_parameters) - + rowcount = property(rowcount) + + def lastrowid(self): + return self.cursor.lastrowid + lastrowid = property(lastrowid) + + def out_parameters(self): + return self.context.out_parameters + out_parameters = property(out_parameters) + def _init_metadata(self): - if hasattr(self, '_ResultProxy__props'): - return self.__props = {} self._key_cache = self._create_key_cache() self.__keys = [] metadata = self.cursor.description if metadata is not None: - typemap = self.dialect.dbapi_type_map() + typemap = self.dialect.dbapi_type_map for i, item in enumerate(metadata): - # sqlite possibly prepending table name to colnames so strip - colname = self.dialect.decode_result_columnname(item[0].split('.')[-1]) - if self.context.typemap is not None: - type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE)) + colname = item[0].decode(self.dialect.encoding) + + if '.' 
in colname: + # sqlite will in some circumstances prepend table name to colnames, so strip + origname = colname + colname = colname.split('.')[-1] else: - type = typemap.get(item[1], types.NULLTYPE) + origname = None + + if self.context.result_map: + try: + (name, obj, type_) = self.context.result_map[colname.lower()] + except KeyError: + (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) + else: + (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) + + rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i) + + if self.__props.setdefault(name.lower(), rec) is not rec: + self.__props[name.lower()] = (type_, self.__ambiguous_processor(name), 0) - rec = (type, type.dialect_impl(self.dialect), i) + # store the "origname" if we truncated (sqlite only) + if origname: + if self.__props.setdefault(origname.lower(), rec) is not rec: + self.__props[origname.lower()] = (type_, self.__ambiguous_processor(origname), 0) - if rec[0] is None: - raise exceptions.DBAPIError("None for metadata " + colname) - if self.__props.setdefault(colname.lower(), rec) is not rec: - self.__props[colname.lower()] = (type, ResultProxy.AmbiguousColumn(colname), 0) self.__keys.append(colname) self.__props[i] = rec + if obj: + for o in obj: + self.__props[o] = rec if self.__echo: self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata]))) @@ -1135,27 +1475,35 @@ class ResultProxy(object): matches it to the appropriate key we got from the result set's metadata; then cache it locally for quick re-access.""" - if isinstance(key, int) and key in props: + if isinstance(key, basestring): + key = key.lower() + try: rec = props[key] - elif isinstance(key, basestring) and key.lower() in props: - rec = props[key.lower()] - elif isinstance(key, sql.ColumnElement): - label = context.column_labels.get(key._label, key.name).lower() - if label in props: - rec = props[label] - - if not "rec" in locals(): + except KeyError: + # fallback for targeting a ColumnElement to a textual expression + # this is a rare use case which only occurs when matching text() + # constructs to ColumnElements + if isinstance(key, expression.ColumnElement): + if key._label and key._label.lower() in props: + return props[key._label.lower()] + elif hasattr(key, 'name') and key.name.lower() in props: + return props[key.name.lower()] raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) return rec return util.PopulateDict(lookup_key) - + + def __ambiguous_processor(self, colname): + def process(value): + raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname) + return process + def close(self): - """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution. + """Close this ResultProxy, and the underlying DB-API cursor corresponding to the execution. If this ResultProxy was generated from an implicit execution, the underlying Connection will also be closed (returns the - underlying DBAPI connection to the connection pool.) + underlying DB-API connection to the connection pool.) This method is also called automatically when all result rows are exhausted. 
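The RowProxy and ResultProxy behavior described above in a small sketch, again using the illustrative users table; column values are reachable by position, string name, attribute, or the originating Column object.

    conn = engine.connect()
    result = conn.execute(users.select())
    row = result.fetchone()
    row[0], row['name'], row.name, row[users.c.name]     # position, string key, attribute, Column
    row.keys()                                            # e.g. ['id', 'name']
    row.items()                                           # list of (key, value) pairs
    result.close()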
@@ -1165,11 +1513,15 @@ class ResultProxy(object): self.cursor.close() if self.connection.should_close_with_result: self.connection.close() - - keys = property(lambda s:s.__keys) - + + def keys(self): + return self.__keys + keys = property(keys) + def _has_key(self, row, key): try: + # _key_cache uses __missing__ in 2.5, so not much alternative + # to catching KeyError self._key_cache[key] return True except KeyError: @@ -1215,78 +1567,126 @@ class ResultProxy(object): return self.context.lastrow_has_defaults() - def supports_sane_rowcount(self): - """Return ``supports_sane_rowcount()`` from the underlying ExecutionContext. + def postfetch_cols(self): + """Return ``postfetch_cols()`` from the underlying ExecutionContext. See ExecutionContext for details. """ + return self.context.postfetch_cols + + def prefetch_cols(self): + return self.context.prefetch_cols + + def supports_sane_rowcount(self): + """Return ``supports_sane_rowcount`` from the dialect. + + """ + return self.dialect.supports_sane_rowcount - return self.context.supports_sane_rowcount() + def supports_sane_multi_rowcount(self): + """Return ``supports_sane_multi_rowcount`` from the dialect. + """ + + return self.dialect.supports_sane_multi_rowcount def _get_col(self, row, key): - rec = self._key_cache[key] - return rec[1].convert_result_value(row[rec[2]], self.dialect) - + try: + type_, processor, index = self._key_cache[key] + except TypeError: + # the 'slice' use case is very infrequent, + # so we use an exception catch to reduce conditionals in _get_col + if isinstance(key, slice): + indices = key.indices(len(row)) + return tuple([self._get_col(row, i) for i in xrange(*indices)]) + else: + raise + + if processor: + return processor(row[index]) + else: + return row[index] + def _fetchone_impl(self): return self.cursor.fetchone() + def _fetchmany_impl(self, size=None): return self.cursor.fetchmany(size) + def _fetchall_impl(self): return self.cursor.fetchall() - - def _process_row(self, row): - return RowProxy(self, row) - + def fetchall(self): - """Fetch all rows, just like DBAPI ``cursor.fetchall()``.""" + """Fetch all rows, just like DB-API ``cursor.fetchall()``.""" - l = [self._process_row(row) for row in self._fetchall_impl()] - self.close() - return l + try: + process_row = self._process_row + l = [process_row(self, row) for row in self._fetchall_impl()] + self.close() + return l + except Exception, e: + self.connection._handle_dbapi_exception(e, None, None, self.cursor) + raise def fetchmany(self, size=None): - """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``.""" + """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``.""" - l = [self._process_row(row) for row in self._fetchmany_impl(size)] - if len(l) == 0: - self.close() - return l + try: + process_row = self._process_row + l = [process_row(self, row) for row in self._fetchmany_impl(size)] + if len(l) == 0: + self.close() + return l + except Exception, e: + self.connection._handle_dbapi_exception(e, None, None, self.cursor) + raise def fetchone(self): - """Fetch one row, just like DBAPI ``cursor.fetchone()``.""" - row = self._fetchone_impl() - if row is not None: - return self._process_row(row) - else: - self.close() - return None + """Fetch one row, just like DB-API ``cursor.fetchone()``.""" + try: + row = self._fetchone_impl() + if row is not None: + return self._process_row(self, row) + else: + self.close() + return None + except Exception, e: + self.connection._handle_dbapi_exception(e, None, None, 
self.cursor) + raise def scalar(self): """Fetch the first column of the first row, and close the result set.""" - row = self._fetchone_impl() + try: + row = self._fetchone_impl() + except Exception, e: + self.connection._handle_dbapi_exception(e, None, None, self.cursor) + raise + try: if row is not None: - return self._process_row(row)[0] + return self._process_row(self, row)[0] else: return None finally: self.close() class BufferedRowResultProxy(ResultProxy): - """``ResultProxy`` that buffers the contents of a selection of rows before - ``fetchone()`` is called. This is to allow the results of - ``cursor.description`` to be available immediately, when interfacing - with a DBAPI that requires rows to be consumed before this information is - available (currently psycopg2, when used with server-side cursors). - - The pre-fetching behavior fetches only one row initially, and then grows - its buffer size by a fixed amount with each successive need for additional - rows up to a size of 100. + """A ResultProxy with row buffering behavior. + + ``ResultProxy`` that buffers the contents of a selection of rows + before ``fetchone()`` is called. This is to allow the results of + ``cursor.description`` to be available immediately, when + interfacing with a DB-API that requires rows to be consumed before + this information is available (currently psycopg2, when used with + server-side cursors). + + The pre-fetching behavior fetches only one row initially, and then + grows its buffer size by a fixed amount with each successive need + for additional rows up to a size of 100. """ def _init_metadata(self): self.__buffer_rows() super(BufferedRowResultProxy, self)._init_metadata() - + # this is a "growth chart" for the buffering of rows. # each successive __buffer_rows call will use the next # value in the list for the buffer size until the max @@ -1298,13 +1698,13 @@ class BufferedRowResultProxy(ResultProxy): 20 : 50, 50 : 100 } - + def __buffer_rows(self): size = getattr(self, '_bufsize', 1) self.__rowbuffer = self.cursor.fetchmany(size) #self.context.engine.logger.debug("Buffered %d rows" % size) self._bufsize = self.size_growth.get(size, size) - + def _fetchone_impl(self): if self.closed: return None @@ -1322,27 +1722,35 @@ class BufferedRowResultProxy(ResultProxy): break result.append(row) return result - + def _fetchall_impl(self): return self.__rowbuffer + list(self.cursor.fetchall()) class BufferedColumnResultProxy(ResultProxy): - """``ResultProxy`` that loads all columns into memory each time fetchone() is - called. If fetchmany() or fetchall() are called, the full grid of results - is fetched. This is to operate with databases where result rows contain "live" - results that fall out of scope unless explicitly fetched. Currently this includes - just cx_Oracle LOB objects, but this behavior is known to exist in other DBAPIs as - well (Pygresql, currently unsupported). - + """A ResultProxy with column buffering behavior. + + ``ResultProxy`` that loads all columns into memory each time + fetchone() is called. If fetchmany() or fetchall() are called, + the full grid of results is fetched. This is to operate with + databases where result rows contain "live" results that fall out + of scope unless explicitly fetched. Currently this includes just + cx_Oracle LOB objects, but this behavior is known to exist in + other DB-APIs as well (Pygresql, currently unsupported). 
""" + _process_row = BufferedColumnRow + def _get_col(self, row, key): - rec = self._key_cache[key] - return row[rec[2]] - - def _process_row(self, row): - sup = super(BufferedColumnResultProxy, self) - row = [sup._get_col(row, i) for i in xrange(len(row))] - return RowProxy(self, row) + try: + rec = self._key_cache[key] + return row[rec[2]] + except TypeError: + # the 'slice' use case is very infrequent, + # so we use an exception catch to reduce conditionals in _get_col + if isinstance(key, slice): + indices = key.indices(len(row)) + return tuple([self._get_col(row, i) for i in xrange(*indices)]) + else: + raise def fetchall(self): l = [] @@ -1364,77 +1772,6 @@ class BufferedColumnResultProxy(ResultProxy): l.append(row) return l -class RowProxy(object): - """Proxy a single cursor row for a parent ResultProxy. - - Mostly follows "ordered dictionary" behavior, mapping result - values to the string-based column name, the integer position of - the result in the row, as well as Column instances which can be - mapped to the original Columns that produced this result set (for - results that correspond to constructed SQL expressions). - """ - - def __init__(self, parent, row): - """RowProxy objects are constructed by ResultProxy objects.""" - - self.__parent = parent - self.__row = row - if self.__parent._ResultProxy__echo: - self.__parent.context.engine.logger.debug("Row " + repr(row)) - - def close(self): - """Close the parent ResultProxy.""" - - self.__parent.close() - - def __contains__(self, key): - return self.__parent._has_key(self.__row, key) - - def __iter__(self): - for i in range(0, len(self.__row)): - yield self.__parent._get_col(self.__row, i) - - def __eq__(self, other): - return (other is self) or (other == tuple([self.__parent._get_col(self.__row, key) for key in range(0, len(self.__row))])) - - def __repr__(self): - return repr(tuple([self.__parent._get_col(self.__row, key) for key in range(0, len(self.__row))])) - - def has_key(self, key): - """Return True if this RowProxy contains the given key.""" - - return self.__parent._has_key(self.__row, key) - - def __getitem__(self, key): - if isinstance(key, slice): - indices = key.indices(len(self)) - return tuple([self.__parent._get_col(self.__row, i) for i in range(*indices)]) - else: - return self.__parent._get_col(self.__row, key) - - def __getattr__(self, name): - try: - return self.__parent._get_col(self.__row, name) - except KeyError, e: - raise AttributeError(e.args[0]) - - def items(self): - """Return a list of tuples, each tuple containing a key/value pair.""" - - return [(key, getattr(self, key)) for key in self.keys()] - - def keys(self): - """Return the list of keys as strings represented by this RowProxy.""" - - return self.__parent.keys - - def values(self): - """Return the values represented by this RowProxy as a list.""" - - return list(self) - - def __len__(self): - return len(self.__row) class SchemaIterator(schema.SchemaVisitor): """A visitor that can gather text into a buffer and execute the contents of the buffer.""" @@ -1470,11 +1807,9 @@ class DefaultRunner(schema.SchemaVisitor): def __init__(self, context): self.context = context - # branch the connection so it doesnt close after result - self.connection = context.connection._branch() - - dialect = property(lambda self:self.context.dialect) - + self.dialect = context.dialect + self.cursor = context.cursor + def get_column_default(self, column): if column.default is not None: return self.traverse_single(column.default) @@ -1499,17 +1834,26 @@ class 
DefaultRunner(schema.SchemaVisitor): def visit_sequence(self, seq): """Do nothing. - Sequences are not supported by default. """ return None def exec_default_sql(self, default): - c = sql.select([default.arg]).compile(bind=self.connection) - return self.connection._execute_compiled(c).scalar() + conn = self.context.connection + c = expression.select([default.arg]).compile(bind=conn) + return conn._execute_compiled(c).scalar() + + def execute_string(self, stmt, params=None): + """execute a string statement, using the raw cursor, + and return a scalar result.""" + conn = self.context._connection + if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements: + stmt = stmt.encode(self.dialect.encoding) + conn._cursor_execute(self.cursor, stmt, params) + return self.cursor.fetchone()[0] def visit_column_onupdate(self, onupdate): - if isinstance(onupdate.arg, sql.ClauseElement): + if isinstance(onupdate.arg, expression.ClauseElement): return self.exec_default_sql(onupdate) elif callable(onupdate.arg): return onupdate.arg(self.context) @@ -1517,9 +1861,33 @@ class DefaultRunner(schema.SchemaVisitor): return onupdate.arg def visit_column_default(self, default): - if isinstance(default.arg, sql.ClauseElement): + if isinstance(default.arg, expression.ClauseElement): return self.exec_default_sql(default) elif callable(default.arg): return default.arg(self.context) else: return default.arg + + +def connection_memoize(key): + """Decorator, memoize a function in a connection.info stash. + + Only applicable to functions which take no arguments other than a + connection. The memo will be stored in ``connection.info[key]``. + + """ + def decorate(fn): + spec = inspect.getargspec(fn) + assert len(spec[0]) == 2 + assert spec[0][1] == 'connection' + assert spec[1:3] == (None, None) + + def decorated(self, connection): + try: + return connection.info[key] + except KeyError: + connection.info[key] = val = fn(self, connection) + return val + + return util.function_named(decorated, fn.__name__) + return decorate diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 962e2ab606..3c1721f9d9 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1,40 +1,74 @@ # engine/default.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Provide default implementations of per-dialect sqlalchemy.engine classes""" +"""Default implementations of per-dialect sqlalchemy.engine classes. -from sqlalchemy import schema, exceptions, sql, types -import sys, re +These are semi-private implementation classes which are only of importance +to database dialect authors; dialects will usually use the classes here +as the base class for their own corresponding classes. 
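A hypothetical use of the connection_memoize() decorator added above; the memo key and the query are illustrative. The decorated method runs its query once per connection and afterwards returns the value cached in connection.info.

    from sqlalchemy.engine.base import connection_memoize

    class VersionHelper(object):
        def sqlite_version(self, connection):
            return connection.scalar("SELECT sqlite_version()")
        sqlite_version = connection_memoize('_example_sqlite_version')(sqlite_version)

    helper = VersionHelper()
    conn = engine.connect()
    helper.sqlite_version(conn)       # executes the query and caches the result
    helper.sqlite_version(conn)       # served from conn.info on subsequent calls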
+ +""" + + +import re, random from sqlalchemy.engine import base +from sqlalchemy.sql import compiler, expression + + +AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', + re.I | re.UNICODE) +SELECT_REGEXP = re.compile(r'\s*SELECT', re.I | re.UNICODE) class DefaultDialect(base.Dialect): """Default implementation of Dialect""" - def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs): + schemagenerator = compiler.SchemaGenerator + schemadropper = compiler.SchemaDropper + statement_compiler = compiler.DefaultCompiler + preparer = compiler.IdentifierPreparer + defaultrunner = base.DefaultRunner + supports_alter = True + supports_unicode_statements = False + max_identifier_length = 9999 + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + preexecute_pk_sequences = False + supports_pk_autoincrement = True + dbapi_type_map = {} + default_paramstyle = 'named' + + def __init__(self, convert_unicode=False, assert_unicode=False, encoding='utf-8', paramstyle=None, dbapi=None, **kwargs): self.convert_unicode = convert_unicode + self.assert_unicode = assert_unicode self.encoding = encoding self.positional = False self._ischema = None self.dbapi = dbapi - self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle) - - def decode_result_columnname(self, name): - """decode a name found in cursor.description to a unicode object.""" - - return name.decode(self.encoding) - - def dbapi_type_map(self): - # most DBAPIs have problems with this (such as, psycocpg2 types - # are unhashable). So far Oracle can return it. - - return {} - - def create_execution_context(self, **kwargs): - return DefaultExecutionContext(self, **kwargs) + if paramstyle is not None: + self.paramstyle = paramstyle + elif self.dbapi is not None: + self.paramstyle = self.dbapi.paramstyle + else: + self.paramstyle = self.default_paramstyle + self.positional = self.paramstyle in ('qmark', 'format', 'numeric') + self.identifier_preparer = self.preparer(self) + + # preexecute_sequences was renamed preexecute_pk_sequences. If a + # subclass has the older property, proxy the new name to the subclass's + # property. + # TODO: remove @ 0.5.0 + if (hasattr(self, 'preexecute_sequences') and + isinstance(getattr(type(self), 'preexecute_pk_sequences'), bool)): + setattr(type(self), 'preexecute_pk_sequences', + property(lambda s: s.preexecute_sequences, doc=( + "Proxy to deprecated preexecute_sequences attribute."))) + + def create_execution_context(self, connection, **kwargs): + return DefaultExecutionContext(self, connection, **kwargs) def type_descriptor(self, typeobj): """Provide a database-specific ``TypeEngine`` object, given @@ -47,24 +81,10 @@ class DefaultDialect(base.Dialect): typeobj = typeobj() return typeobj - def supports_unicode_statements(self): - """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" - return False - def max_identifier_length(self): - # TODO: probably raise this and fill out - # db modules better - return 9999 - - def supports_alter(self): - return True - def oid_column_name(self, column): return None - def supports_sane_rowcount(self): - return True - def do_begin(self, connection): """Implementations might want to put logic here for turning autocommit on/off, etc. @@ -77,7 +97,6 @@ class DefaultDialect(base.Dialect): autocommit on/off, etc. 
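Dialect authors build on DefaultDialect by overriding the class-level flags and hooks shown above. The skeleton below is purely hypothetical; the attribute values and the disconnect test are illustrative, not taken from any real dialect.

    from sqlalchemy.engine import default

    class ExampleDialect(default.DefaultDialect):
        supports_unicode_statements = True
        max_identifier_length = 63
        default_paramstyle = 'pyformat'

        def is_disconnect(self, e):
            # illustrative only: match whatever message the DB-API raises on a dead connection
            return 'connection has been closed' in str(e)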
""" - #print "ENGINE ROLLBACK ON ", connection.connection connection.rollback() def do_commit(self, connection): @@ -85,99 +104,116 @@ class DefaultDialect(base.Dialect): autocommit on/off, etc. """ - #print "ENGINE COMMIT ON ", connection.connection connection.commit() - + + def create_xid(self): + """Create a random two-phase transaction ID. + + This id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). Its format is unspecified.""" + + return "_sa_%032x" % random.randint(0,2**128) + def do_savepoint(self, connection, name): - connection.execute(sql.SavepointClause(name)) + connection.execute(expression.SavepointClause(name)) def do_rollback_to_savepoint(self, connection, name): - connection.execute(sql.RollbackToSavepointClause(name)) + connection.execute(expression.RollbackToSavepointClause(name)) def do_release_savepoint(self, connection, name): - connection.execute(sql.ReleaseSavepointClause(name)) + connection.execute(expression.ReleaseSavepointClause(name)) - def do_executemany(self, cursor, statement, parameters, **kwargs): + def do_executemany(self, cursor, statement, parameters, context=None): cursor.executemany(statement, parameters) - def do_execute(self, cursor, statement, parameters, **kwargs): + def do_execute(self, cursor, statement, parameters, context=None): cursor.execute(statement, parameters) - def defaultrunner(self, context): - return base.DefaultRunner(context) - def is_disconnect(self, e): return False - - def _set_paramstyle(self, style): - self._paramstyle = style - self._figure_paramstyle(style) - paramstyle = property(lambda s:s._paramstyle, _set_paramstyle) - - - def _figure_paramstyle(self, paramstyle=None, default='named'): - if paramstyle is not None: - self._paramstyle = paramstyle - elif self.dbapi is not None: - self._paramstyle = self.dbapi.paramstyle - else: - self._paramstyle = default - - if self._paramstyle == 'named': - self.positional=False - elif self._paramstyle == 'pyformat': - self.positional=False - elif self._paramstyle == 'qmark' or self._paramstyle == 'format' or self._paramstyle == 'numeric': - # for positional, use pyformat internally, ANSICompiler will convert - # to appropriate character upon compilation - self.positional = True - else: - raise exceptions.DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle) - - def _get_ischema(self): - if self._ischema is None: - import sqlalchemy.databases.information_schema as ischema - self._ischema = ischema.ISchema(self) - return self._ischema - ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""") class DefaultExecutionContext(base.ExecutionContext): def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): self.dialect = dialect - self.connection = connection + self._connection = self.root_connection = connection self.compiled = compiled - + self.engine = connection.engine + if compiled is not None: - self.typemap = compiled.typemap - self.column_labels = compiled.column_labels - self.statement = unicode(compiled) - if parameters is None: - self.compiled_parameters = compiled.construct_params({}) - elif not isinstance(parameters, (list, tuple)): - self.compiled_parameters = compiled.construct_params(parameters) + # compiled clauseelement. 
process bind params, process table defaults, + # track collections used by ResultProxy to target and process results + + self.processors = dict([ + (key, value) for key, value in + [( + compiled.bind_names[bindparam], + bindparam.bind_processor(self.dialect) + ) for bindparam in compiled.bind_names] + if value is not None + ]) + + self.result_map = compiled.result_map + + if not dialect.supports_unicode_statements: + self.statement = unicode(compiled).encode(self.dialect.encoding) else: - self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters] - if len(self.compiled_parameters) == 1: - self.compiled_parameters = self.compiled_parameters[0] + self.statement = unicode(compiled) + + self.isinsert = compiled.isinsert + self.isupdate = compiled.isupdate + if isinstance(compiled.statement, expression._TextClause): + self.returns_rows = self.returns_rows_text(self.statement) + self.should_autocommit = compiled.statement._autocommit or self.should_autocommit_text(self.statement) + else: + self.returns_rows = self.returns_rows_compiled(compiled) + self.should_autocommit = getattr(compiled.statement, '_autocommit', False) or self.should_autocommit_compiled(compiled) + + if not parameters: + self.compiled_parameters = [compiled.construct_params()] + self.executemany = False + else: + self.compiled_parameters = [compiled.construct_params(m) for m in parameters] + self.executemany = len(parameters) > 1 + + self.cursor = self.create_cursor() + if self.isinsert or self.isupdate: + self.__process_defaults() + self.parameters = self.__convert_compiled_params(self.compiled_parameters) + elif statement is not None: - self.typemap = self.column_labels = None + # plain text statement. + self.result_map = None self.parameters = self.__encode_param_keys(parameters) - self.statement = statement + self.executemany = len(parameters) > 1 + if isinstance(statement, unicode) and not dialect.supports_unicode_statements: + self.statement = statement.encode(self.dialect.encoding) + else: + self.statement = statement + self.isinsert = self.isupdate = False + self.cursor = self.create_cursor() + self.returns_rows = self.returns_rows_text(statement) + self.should_autocommit = self.should_autocommit_text(statement) else: + # no statement. used for standalone ColumnDefault execution. self.statement = None - - if self.statement is not None and not dialect.supports_unicode_statements(): - self.statement = self.statement.encode(self.dialect.encoding) - - self.cursor = self.create_cursor() - - engine = property(lambda s:s.connection.engine) - + self.isinsert = self.isupdate = self.executemany = self.returns_rows = self.should_autocommit = False + self.cursor = self.create_cursor() + + connection = property(lambda s:s._connection._branch()) + def __encode_param_keys(self, params): - """apply string encoding to the keys of dictionary-based bind parameters""" - if self.dialect.positional or self.dialect.supports_unicode_statements(): - return params + """apply string encoding to the keys of dictionary-based bind parameters. 
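For plain textual SQL, the returns_rows / should_autocommit decisions reduce to the two module-level regexes defined near the top of default.py; a quick sketch of the classification they perform::

    import re

    AUTOCOMMIT_REGEXP = re.compile(
        r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE)
    SELECT_REGEXP = re.compile(r'\s*SELECT', re.I | re.UNICODE)

    # a text statement "returns rows" if it leads with SELECT, and is
    # autocommitted if it leads with a data- or schema-modifying keyword
    assert SELECT_REGEXP.match("  select name from users")
    assert AUTOCOMMIT_REGEXP.match("UPDATE users SET name='ed'")
    assert not AUTOCOMMIT_REGEXP.match("select name from users")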
+ + This is only used executing textual, non-compiled SQL expressions.""" + + if self.dialect.positional or self.dialect.supports_unicode_statements: + if params: + return params + elif self.dialect.positional: + return [()] + else: + return [{}] else: def proc(d): # sigh, sometimes we get positional arguments with a dialect @@ -185,50 +221,71 @@ class DefaultExecutionContext(base.ExecutionContext): if not isinstance(d, dict): return d return dict([(k.encode(self.dialect.encoding), d[k]) for k in d]) - if isinstance(params, list): - return [proc(d) for d in params] - else: - return proc(params) - - def __convert_compiled_params(self, parameters): - executemany = parameters is not None and isinstance(parameters, list) - encode = not self.dialect.supports_unicode_statements() - # the bind params are a CompiledParams object. but all the DBAPI's hate - # that object (or similar). so convert it to a clean - # dictionary/list/tuple of dictionary/tuple of list - if parameters is not None: - if self.dialect.positional: - if executemany: - parameters = [p.get_raw_list() for p in parameters] - else: - parameters = parameters.get_raw_list() - else: - if executemany: - parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters] + return [proc(d) for d in params] or [{}] + + def __convert_compiled_params(self, compiled_parameters): + """convert the dictionary of bind parameter values into a dict or list + to be sent to the DBAPI's execute() or executemany() method. + """ + + processors = self.processors + parameters = [] + if self.dialect.positional: + for compiled_params in compiled_parameters: + param = [] + for key in self.compiled.positiontup: + if key in processors: + param.append(processors[key](compiled_params[key])) + else: + param.append(compiled_params[key]) + parameters.append(param) + else: + encode = not self.dialect.supports_unicode_statements + for compiled_params in compiled_parameters: + param = {} + if encode: + encoding = self.dialect.encoding + for key in compiled_params: + if key in processors: + param[key.encode(encoding)] = processors[key](compiled_params[key]) + else: + param[key.encode(encoding)] = compiled_params[key] else: - parameters = parameters.get_raw_dict(encode_keys=encode) + for key in compiled_params: + if key in processors: + param[key] = processors[key](compiled_params[key]) + else: + param[key] = compiled_params[key] + parameters.append(param) return parameters - - def is_select(self): - """return TRUE if the statement is expected to have result rows.""" - - return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None + + def returns_rows_compiled(self, compiled): + return isinstance(compiled.statement, expression.Selectable) + + def returns_rows_text(self, statement): + return SELECT_REGEXP.match(statement) + + def should_autocommit_compiled(self, compiled): + return isinstance(compiled.statement, expression._UpdateBase) + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_REGEXP.match(statement) + def create_cursor(self): - return self.connection.connection.cursor() + return self._connection.connection.cursor() def pre_execution(self): self.pre_exec() - + def post_execution(self): self.post_exec() - + def result(self): return self.get_result_proxy() - + def pre_exec(self): - self._process_defaults() - self.parameters = self.__convert_compiled_params(self.compiled_parameters) + pass def post_exec(self): pass @@ -243,7 +300,10 @@ class DefaultExecutionContext(base.ExecutionContext): return self.cursor.rowcount def 
supports_sane_rowcount(self): - return self.dialect.supports_sane_rowcount() + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + return self.dialect.supports_sane_multi_rowcount def last_inserted_ids(self): return self._last_inserted_ids @@ -255,100 +315,84 @@ class DefaultExecutionContext(base.ExecutionContext): return self._last_updated_params def lastrow_has_defaults(self): - return self._lastrow_has_defaults + return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) def set_input_sizes(self): """Given a cursor and ClauseParameters, call the appropriate - style of ``setinputsizes()`` on the cursor, using DBAPI types + style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. """ - if isinstance(self.compiled_parameters, list): - plist = self.compiled_parameters - else: - plist = [self.compiled_parameters] + types = dict([ + (self.compiled.bind_names[bindparam], bindparam.type) + for bindparam in self.compiled.bind_names + ]) + if self.dialect.positional: inputsizes = [] - for params in plist[0:1]: - for key in params.positional: - typeengine = params.get_type(key) - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: - inputsizes.append(dbtype) - self.cursor.setinputsizes(*inputsizes) + for key in self.compiled.positiontup: + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if dbtype is not None: + inputsizes.append(dbtype) + try: + self.cursor.setinputsizes(*inputsizes) + except Exception, e: + self._connection._handle_dbapi_exception(e, None, None, None) + raise else: inputsizes = {} - for params in plist[0:1]: - for key in params.keys(): - typeengine = params.get_type(key) - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: - inputsizes[key] = dbtype - self.cursor.setinputsizes(**inputsizes) - - def _process_defaults(self): + for key in self.compiled.bind_names.values(): + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if dbtype is not None: + inputsizes[key.encode(self.dialect.encoding)] = dbtype + try: + self.cursor.setinputsizes(**inputsizes) + except Exception, e: + self._connection._handle_dbapi_exception(e, None, None, None) + raise + + def __process_defaults(self): """generate default values for compiled insert/update statements, and generate last_inserted_ids() collection.""" - # TODO: cleanup - if self.compiled.isinsert: - if isinstance(self.compiled_parameters, list): - plist = self.compiled_parameters - else: - plist = [self.compiled_parameters] + if self.executemany: + if len(self.compiled.prefetch): + drunner = self.dialect.defaultrunner(self) + params = self.compiled_parameters + for param in params: + # assign each dict of params to self.compiled_parameters; + # this allows user-defined default generators to access the full + # set of bind params for the row + self.compiled_parameters = param + for c in self.compiled.prefetch: + if self.isinsert: + val = drunner.get_column_default(c) + else: + val = drunner.get_column_onupdate(c) + if val is not None: + param[c.key] = val + self.compiled_parameters = params + + else: + compiled_parameters = self.compiled_parameters[0] drunner = self.dialect.defaultrunner(self) - self._lastrow_has_defaults = False - for param in plist: - last_inserted_ids = [] - # check the "default" status of each 
column in the table - for c in self.compiled.statement.table.c: - # check if it will be populated by a SQL clause - we'll need that - # after execution. - if c in self.compiled.inline_params: - self._lastrow_has_defaults = True - if c.primary_key: - last_inserted_ids.append(None) - # check if its not present at all. see if theres a default - # and fire it off, and add to bind parameters. if - # its a pk, add the value to our last_inserted_ids list, - # or, if its a SQL-side default, let it fire off on the DB side, but we'll need - # the SQL-generated value after execution. - elif not c.key in param or param.get_original(c.key) is None: - if isinstance(c.default, schema.PassiveDefault): - self._lastrow_has_defaults = True - newid = drunner.get_column_default(c) - if newid is not None: - param.set_value(c.key, newid) - if c.primary_key: - last_inserted_ids.append(param.get_processed(c.key)) - elif c.primary_key: - last_inserted_ids.append(None) - # its an explicitly passed pk value - add it to - # our last_inserted_ids list. - elif c.primary_key: - last_inserted_ids.append(param.get_processed(c.key)) - # TODO: we arent accounting for executemany() situations - # here (hard to do since lastrowid doesnt support it either) - self._last_inserted_ids = last_inserted_ids - self._last_inserted_params = param - elif self.compiled.isupdate: - if isinstance(self.compiled_parameters, list): - plist = self.compiled_parameters + + for c in self.compiled.prefetch: + if self.isinsert: + val = drunner.get_column_default(c) + else: + val = drunner.get_column_onupdate(c) + + if val is not None: + compiled_parameters[c.key] = val + + if self.isinsert: + self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key] + self._last_inserted_params = compiled_parameters else: - plist = [self.compiled_parameters] - drunner = self.dialect.defaultrunner(self) - self._lastrow_has_defaults = False - for param in plist: - # check the "onupdate" status of each column in the table - for c in self.compiled.statement.table.c: - # it will be populated by a SQL clause - we'll need that - # after execution. - if c in self.compiled.inline_params: - pass - # its not in the bind parameters, and theres an "onupdate" defined for the column; - # execute it and add to bind params - elif c.onupdate is not None and (not c.key in param or param.get_original(c.key) is None): - value = drunner.get_column_onupdate(c) - if value is not None: - param.set_value(c.key, value) - self._last_updated_params = param + self._last_updated_params = compiled_parameters + + self.postfetch_cols = self.compiled.postfetch + self.prefetch_cols = self.compiled.prefetch \ No newline at end of file diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 0c59ee8ebf..d4a0ad8418 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -1,10 +1,13 @@ -"""Define different strategies for creating new instances of sql.Engine. +"""Strategies for creating new instances of Engine types. -By default there are two, one which is the "thread-local" strategy, -one which is the "plain" strategy. +These are semi-private implementation classes which +provide the underlying behavior for the "strategy" keyword argument +available on [sqlalchemy.engine#create_engine()]. +Current available options are ``plain``, ``threadlocal``, and +``mock``. 
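The strategy names listed in the docstring are selected through the ``strategy`` keyword of create_engine(); for instance (in-memory SQLite used purely for illustration)::

    from sqlalchemy import create_engine

    # 'plain' is the default strategy; 'threadlocal' produces the TLEngine
    # described further below; 'mock' produces an Engine-like object that
    # hands each statement to a user-supplied executor callable.
    plain_engine = create_engine('sqlite://')
    tl_engine = create_engine('sqlite://', strategy='threadlocal')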
-New strategies can be added via constructing a new EngineStrategy -object which will add itself to the list of available strategies. +New strategies can be added via new ``EngineStrategy`` +classes. """ @@ -15,9 +18,10 @@ from sqlalchemy import pool as poollib strategies = {} class EngineStrategy(object): - """Define a function that receives input arguments and produces an - instance of sql.Engine, typically an instance - sqlalchemy.engine.base.Engine or a subclass. + """An adaptor that processes input arguements and produces an Engine. + + Provides a ``create`` method that receives input arguments and + produces an instance of base.Engine or a subclass. """ def __init__(self, name): @@ -30,11 +34,13 @@ class EngineStrategy(object): strategies[self.name] = self def create(self, *args, **kwargs): - """Given arguments, returns a new sql.Engine instance.""" + """Given arguments, returns a new Engine instance.""" raise NotImplementedError() class DefaultEngineStrategy(EngineStrategy): + """Base class for built-in stratgies.""" + def create(self, name_or_url, **kwargs): # create url.URL object u = url.make_url(name_or_url) @@ -54,9 +60,9 @@ class DefaultEngineStrategy(EngineStrategy): if k in kwargs: dbapi_args[k] = kwargs.pop(k) dbapi = dialect_cls.dbapi(**dbapi_args) - + dialect_args['dbapi'] = dbapi - + # create dialect dialect = dialect_cls(**dialect_args) @@ -71,18 +77,24 @@ class DefaultEngineStrategy(EngineStrategy): try: return dbapi.connect(*cargs, **cparams) except Exception, e: - raise exceptions.DBAPIError("Connection failed", e) + raise exceptions.DBAPIError.instance(None, None, e) creator = kwargs.pop('creator', connect) - poolclass = kwargs.pop('poolclass', getattr(dialect_cls, 'poolclass', poollib.QueuePool)) + poolclass = (kwargs.pop('poolclass', None) or + getattr(dialect_cls, 'poolclass', poollib.QueuePool)) pool_args = {} - # consume pool arguments from kwargs, translating a few of the arguments + # consume pool arguments from kwargs, translating a few of + # the arguments + translate = {'echo': 'echo_pool', + 'timeout': 'pool_timeout', + 'recycle': 'pool_recycle', + 'use_threadlocal':'pool_threadlocal'} for k in util.get_cls_kwargs(poolclass): - tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k) + tk = translate.get(k, k) if tk in kwargs: pool_args[k] = kwargs.pop(tk) - pool_args['use_threadlocal'] = self.pool_threadlocal() + pool_args.setdefault('use_threadlocal', self.pool_threadlocal()) pool = poolclass(creator, **pool_args) else: if isinstance(pool, poollib._DBProxy): @@ -98,9 +110,15 @@ class DefaultEngineStrategy(EngineStrategy): engine_args[k] = kwargs.pop(k) # all kwargs should be consumed - if len(kwargs): - raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s. Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__)) - + if kwargs: + raise TypeError( + "Invalid argument(s) %s sent to create_engine(), " + "using configuration %s/%s/%s. Please check that the " + "keyword arguments are appropriate for this combination " + "of components." 
% (','.join(["'%s'" % k for k in kwargs]), + dialect.__class__.__name__, + pool.__class__.__name__, + engineclass.__name__)) return engineclass(pool, dialect, u, **engine_args) def pool_threadlocal(self): @@ -110,11 +128,13 @@ class DefaultEngineStrategy(EngineStrategy): raise NotImplementedError() class PlainEngineStrategy(DefaultEngineStrategy): + """Strategy for configuring a regular Engine.""" + def __init__(self): DefaultEngineStrategy.__init__(self, 'plain') def pool_threadlocal(self): - return False + return True def get_engine_cls(self): return base.Engine @@ -122,6 +142,8 @@ class PlainEngineStrategy(DefaultEngineStrategy): PlainEngineStrategy() class ThreadLocalEngineStrategy(DefaultEngineStrategy): + """Strategy for configuring an Engine with thredlocal behavior.""" + def __init__(self): DefaultEngineStrategy.__init__(self, 'threadlocal') @@ -135,11 +157,15 @@ ThreadLocalEngineStrategy() class MockEngineStrategy(EngineStrategy): - """Produces a single Connection object which dispatches statement executions - to a passed-in function""" + """Strategy for configuring an Engine-like object with mocked execution. + + Produces a single mock Connectable object which dispatches + statement execution to a passed-in function. + """ + def __init__(self): EngineStrategy.__init__(self, 'mock') - + def create(self, name_or_url, executor, **kwargs): # create url.URL object u = url.make_url(name_or_url) @@ -164,20 +190,21 @@ class MockEngineStrategy(EngineStrategy): engine = property(lambda s: s) dialect = property(lambda s:s._dialect) - + def contextual_connect(self, **kwargs): return self def compiler(self, statement, parameters, **kwargs): - return self._dialect.compiler(statement, parameters, engine=self, **kwargs) + return self._dialect.compiler( + statement, parameters, engine=self, **kwargs) def create(self, entity, **kwargs): kwargs['checkfirst'] = False - entity.accept_visitor(self.dialect.schemagenerator(self, **kwargs)) + self.dialect.schemagenerator(self.dialect ,self, **kwargs).traverse(entity) def drop(self, entity, **kwargs): kwargs['checkfirst'] = False - entity.accept_visitor(self.dialect.schemadropper(self, **kwargs)) + self.dialect.schemadropper(self.dialect, self, **kwargs).traverse(entity) def execute(self, object, *multiparams, **params): raise NotImplementedError() diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index b6ba54ea58..e4b2859dc5 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -1,13 +1,13 @@ -from sqlalchemy import util -from sqlalchemy.engine import base - -"""Provide a thread-local transactional wrapper around the root Engine class. +"""Provides a thread-local transactional wrapper around the root Engine class. -Multiple calls to engine.connect() will return the same connection for -the same thread. also provides begin/commit methods on the engine -itself which correspond to a thread-local transaction. +The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag +with [sqlalchemy.engine#create_engine()]. This module is semi-private and is +invoked automatically when the threadlocal engine strategy is used. 
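The translation table above maps create_engine() keywords onto the pool constructor's own argument names; roughly, and assuming the QueuePool is in use::

    from sqlalchemy import create_engine, pool

    # echo_pool, pool_timeout, pool_recycle and pool_threadlocal are consumed
    # by create_engine() and passed to the pool as echo, timeout, recycle and
    # use_threadlocal respectively
    engine = create_engine('sqlite://',
                           poolclass=pool.QueuePool,
                           echo_pool=True,
                           pool_recycle=3600,
                           pool_timeout=30)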
""" +from sqlalchemy import util +from sqlalchemy.engine import base + class TLSession(object): def __init__(self, engine): self.engine = engine @@ -17,7 +17,7 @@ class TLSession(object): try: return self.__transaction._increment_connect() except AttributeError: - return TLConnection(self, close_with_result=close_with_result) + return TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result) def reset(self): try: @@ -28,127 +28,173 @@ class TLSession(object): pass self.__tcount = 0 + def _conn_closed(self): + if self.__tcount == 1: + self.__trans._trans.rollback() + self.reset() + def in_transaction(self): return self.__tcount > 0 - def begin(self): + def prepare(self): + if self.__tcount == 1: + self.__trans._trans.prepare() + + def begin_twophase(self, xid=None): if self.__tcount == 0: self.__transaction = self.get_connection() - self.__trans = self.__transaction._begin() + self.__trans = self.__transaction._begin_twophase(xid=xid) + self.__tcount += 1 + return self.__trans + + def begin(self, **kwargs): + if self.__tcount == 0: + self.__transaction = self.get_connection() + self.__trans = self.__transaction._begin(**kwargs) self.__tcount += 1 return self.__trans def rollback(self): if self.__tcount > 0: try: - self.__trans._rollback_impl() + self.__trans._trans.rollback() finally: self.reset() def commit(self): if self.__tcount == 1: try: - self.__trans._commit_impl() + self.__trans._trans.commit() finally: self.reset() elif self.__tcount > 1: self.__tcount -= 1 - + + def close(self): + if self.__tcount == 1: + self.rollback() + elif self.__tcount > 1: + self.__tcount -= 1 + def is_begun(self): return self.__tcount > 0 + class TLConnection(base.Connection): - def __init__(self, session, close_with_result): - base.Connection.__init__(self, session.engine, close_with_result=close_with_result) + def __init__(self, session, connection, close_with_result): + base.Connection.__init__(self, session.engine, connection, close_with_result=close_with_result) self.__session = session self.__opencount = 1 - session = property(lambda s:s.__session) + def session(self): + return self.__session + session = property(session) def _increment_connect(self): self.__opencount += 1 return self - def _begin(self): - return TLTransaction(self) + def _begin(self, **kwargs): + return TLTransaction( + super(TLConnection, self).begin(**kwargs), self.__session) + + def _begin_twophase(self, xid=None): + return TLTransaction( + super(TLConnection, self).begin_twophase(xid=xid), self.__session) def in_transaction(self): return self.session.in_transaction() - def begin(self): - return self.session.begin() + def begin(self, **kwargs): + return self.session.begin(**kwargs) + + def begin_twophase(self, xid=None): + return self.session.begin_twophase(xid=xid) def close(self): if self.__opencount == 1: base.Connection.close(self) + self.__session._conn_closed() self.__opencount -= 1 def _force_close(self): self.__opencount = 0 base.Connection.close(self) -class TLTransaction(base.RootTransaction): - def _commit_impl(self): - base.Transaction.commit(self) - def _rollback_impl(self): - base.Transaction.rollback(self) +class TLTransaction(base.Transaction): + def __init__(self, trans, session): + self._trans = trans + self._session = session - def commit(self): - self.connection.session.commit() + def connection(self): + return self._trans.connection + connection = property(connection) + + def is_active(self): + return self._trans.is_active + is_active = property(is_active) def rollback(self): - 
self.connection.session.rollback() + self._session.rollback() + + def prepare(self): + self._session.prepare() + + def commit(self): + self._session.commit() + + def close(self): + self._session.close() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self._trans.__exit__(type, value, traceback) + class TLEngine(base.Engine): """An Engine that includes support for thread-local managed transactions. - This engine is better suited to be used with threadlocal Pool - object. + The TLEngine relies upon its Pool having "threadlocal" behavior, + so that once a connection is checked out for the current thread, + you get that same connection repeatedly. """ def __init__(self, *args, **kwargs): - """The TLEngine relies upon the Pool having - "threadlocal" behavior, so that once a connection is checked out - for the current thread, you get that same connection - repeatedly. - """ + """Construct a new TLEngine.""" super(TLEngine, self).__init__(*args, **kwargs) self.context = util.ThreadLocal() - def raw_connection(self): - """Return a DBAPI connection.""" - - return self.pool.connect() - - def connect(self, **kwargs): - """Return a Connection that is not thread-locally scoped. - - This is the equivalent to calling ``connect()`` on a - ComposedSQLEngine. - """ - - return base.Connection(self, self.pool.unique_connection()) - - def _session(self): + def session(self): + "Returns the current thread's TLSession" if not hasattr(self.context, 'session'): self.context.session = TLSession(self) return self.context.session - session = property(_session, doc="returns the current thread's TLSession") + session = property(session) def contextual_connect(self, **kwargs): """Return a TLConnection which is thread-locally scoped.""" return self.session.get_connection(**kwargs) - def begin(self): - return self.session.begin() + def begin_twophase(self, **kwargs): + return self.session.begin_twophase(**kwargs) + + def begin(self, **kwargs): + return self.session.begin(**kwargs) + def prepare(self): + self.session.prepare() + def commit(self): self.session.commit() def rollback(self): self.session.rollback() + def __repr__(self): + return 'TLEngine(%s)' % str(self.url) diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 1da76d7b2a..7364f0227c 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -1,4 +1,10 @@ -"""Provide the URL object as well as the make_url parsing function.""" +"""Provides the [sqlalchemy.engine.url#URL] class which encapsulates +information about a database connection specification. + +The URL object is created automatically when [sqlalchemy.engine#create_engine()] is called +with a string argument; alternatively, the URL is a public-facing construct which can +be used directly and is also accepted directly by ``create_engine()``. +""" import re, cgi, sys, urllib from sqlalchemy import exceptions @@ -7,15 +13,16 @@ from sqlalchemy import exceptions class URL(object): """Represent the components of a URL used to connect to a database. - This object is suitable to be passed directly to a ``create_engine()`` - call. The fields of the URL are parsed from a string by the - ``module-level make_url()`` function. the string format of the URL is - an RFC-1738-style string. + This object is suitable to be passed directly to a + ``create_engine()`` call. The fields of the URL are parsed from a + string by the ``module-level make_url()`` function. the string + format of the URL is an RFC-1738-style string. 
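Putting the thread-local pieces together: with the threadlocal strategy, begin(), commit() and rollback() on the engine itself drive a transaction scoped to the calling thread, and every execute() issued by that thread rides on the same pooled connection (the table below is made up for the example)::

    from sqlalchemy import create_engine

    engine = create_engine('sqlite://', strategy='threadlocal')
    engine.execute("create table widgets (id integer)")

    engine.begin()
    try:
        engine.execute("insert into widgets (id) values (1)")
    except:
        engine.rollback()
        raise
    else:
        engine.commit()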
Attributes on URL include: drivername - The name of the database backend. this name will correspond to a module in sqlalchemy/databases + the name of the database backend. This name will correspond to + a module in sqlalchemy/databases or a third party plug-in. username The user name for the connection. @@ -33,7 +40,8 @@ class URL(object): The database. query - A dictionary containing key/value pairs representing the URL's query string. + A dictionary containing key/value pairs representing the URL's + query string. """ def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None, query=None): @@ -61,57 +69,75 @@ class URL(object): s += ':' + str(self.port) if self.database is not None: s += '/' + self.database - if len(self.query): + if self.query: keys = self.query.keys() keys.sort() s += '?' + "&".join(["%s=%s" % (k, self.query[k]) for k in keys]) return s - + + def __eq__(self, other): + return \ + isinstance(other, URL) and \ + self.drivername == other.drivername and \ + self.username == other.username and \ + self.password == other.password and \ + self.host == other.host and \ + self.database == other.database and \ + self.query == other.query + def get_dialect(self): """Return the SQLAlchemy database dialect class corresponding to this URL's driver name.""" - dialect=None - if self.drivername == 'ansi': - import sqlalchemy.ansisql - return sqlalchemy.ansisql.dialect - + try: - module=getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) - dialect=module.dialect + module = getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) + return module.dialect except ImportError: if sys.exc_info()[2].tb_next is None: import pkg_resources for res in pkg_resources.iter_entry_points('sqlalchemy.databases'): - if res.name==self.drivername: - dialect=res.load() - else: - raise - if dialect is not None: - return dialect - raise ImportError('unknown database %r' % self.drivername) + if res.name == self.drivername: + return res.load() + raise - def translate_connect_args(self, names): - """Translate this URL's attributes into a dictionary of connection arguments. - - Given a list of argument names corresponding to the URL - attributes (`host`, `database`, `username`, `password`, - `port`), will assemble the attribute values of this URL into - the dictionary using the given names. + def translate_connect_args(self, names=[], **kw): + """Translate url attributes into a dictionary of connection arguments. + + Returns attributes of this url (`host`, `database`, `username`, + `password`, `port`) as a plain dictionary. The attribute names are + used as the keys by default. Unset or false attributes are omitted + from the final dictionary. + + \**kw + Optional, alternate key names for url attributes:: + + # return 'username' as 'user' + username='user' + + # omit 'database' + database=None + + names + Deprecated. A list of key names. Equivalent to the keyword + usage, must be provided in the order above. 
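A short sketch of URL attribute access and the reworked translate_connect_args(); the driver name and credentials are arbitrary::

    from sqlalchemy.engine.url import make_url

    u = make_url('postgres://scott:tiger@localhost/test?charset=utf8')
    u.drivername    # 'postgres'
    u.query         # {'charset': 'utf8'}

    # by default the URL attribute names become the keys, and unset
    # attributes are simply left out
    u.translate_connect_args()
    # {'username': 'scott', 'password': 'tiger',
    #  'host': 'localhost', 'database': 'test'}

    # keyword arguments rename or suppress individual keys
    u.translate_connect_args(username='user', database=None)
    # {'user': 'scott', 'password': 'tiger', 'host': 'localhost'}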
""" - a = {} + translated = {} attribute_names = ['host', 'database', 'username', 'password', 'port'] - for n in names: - sname = attribute_names.pop(0) - if n is None: - continue - if getattr(self, sname, None): - a[n] = getattr(self, sname) - return a + for sname in attribute_names: + if names: + name = names.pop(0) + elif sname in kw: + name = kw[sname] + else: + name = sname + if name is not None and getattr(self, sname, False): + translated[name] = getattr(self, sname) + return translated def make_url(name_or_url): """Given a string or unicode instance, produce a new URL instance. - The given string is parsed according to the rfc1738 spec. If an + The given string is parsed according to the RFC 1738 spec. If an existing URL object is passed, just returns the object. """ @@ -122,36 +148,40 @@ def make_url(name_or_url): def _parse_rfc1738_args(name): pattern = re.compile(r''' - (\w+):// + (?P\w+):// (?: - ([^:/]*) - (?::([^/]*))? + (?P[^:/]*) + (?::(?P[^/]*))? @)? (?: - ([^/:]*) - (?::([^/]*))? + (?P[^/:]*) + (?::(?P[^/]*))? )? - (?:/(.*))? + (?:/(?P.*))? ''' , re.X) m = pattern.match(name) if m is not None: - (name, username, password, host, port, database) = m.group(1, 2, 3, 4, 5, 6) - if database is not None: - tokens = database.split(r"?", 2) - database = tokens[0] - query = (len(tokens) > 1 and dict( cgi.parse_qsl(tokens[1]) ) or None) + components = m.groupdict() + if components['database'] is not None: + tokens = components['database'].split('?', 2) + components['database'] = tokens[0] + query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None if query is not None: query = dict([(k.encode('ascii'), query[k]) for k in query]) else: query = None - opts = {'username':username,'password':password,'host':host,'port':port,'database':database, 'query':query} - if opts['password'] is not None: - opts['password'] = urllib.unquote_plus(opts['password']) - return URL(name, **opts) + components['query'] = query + + if components['password'] is not None: + components['password'] = urllib.unquote_plus(components['password']) + + name = components.pop('name') + return URL(name, **components) else: - raise exceptions.ArgumentError("Could not parse rfc1738 URL from string '%s'" % name) + raise exceptions.ArgumentError( + "Could not parse rfc1738 URL from string '%s'" % name) def _parse_keyvalue_args(name): m = re.match( r'(\w+)://(.*)', name) diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exceptions.py index 55c345bd72..43623df93f 100644 --- a/lib/sqlalchemy/exceptions.py +++ b/lib/sqlalchemy/exceptions.py @@ -1,30 +1,16 @@ # exceptions.py - exceptions for SQLAlchemy -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Exceptions used with SQLAlchemy. +The base exception class is SQLAlchemyError. Exceptions which are raised as a result +of DBAPI exceptions are all subclasses of [sqlalchemy.exceptions#DBAPIError].""" class SQLAlchemyError(Exception): """Generic error class.""" - pass - -class SQLError(SQLAlchemyError): - """Raised when the execution of a SQL statement fails. - - Includes accessors for the underlying exception, as well as the - SQL and bind parameters. 
- """ - - def __init__(self, statement, params, orig): - SQLAlchemyError.__init__(self, "(%s) %s"% (orig.__class__.__name__, str(orig))) - self.statement = statement - self.params = params - self.orig = orig - - def __str__(self): - return SQLAlchemyError.__str__(self) + " " + repr(self.statement) + " " + repr(self.params) class ArgumentError(SQLAlchemyError): """Raised for all those conditions where invalid arguments are @@ -32,30 +18,26 @@ class ArgumentError(SQLAlchemyError): construction time state errors. """ - pass class CompileError(SQLAlchemyError): """Raised when an error occurs during SQL compilation""" - - pass - + + class TimeoutError(SQLAlchemyError): """Raised when a connection pool times out on getting a connection.""" - pass class ConcurrentModificationError(SQLAlchemyError): """Raised when a concurrent modification condition is detected.""" - pass class CircularDependencyError(SQLAlchemyError): """Raised by topological sorts when a circular dependency is detected""" - pass - + + class FlushError(SQLAlchemyError): """Raised when an invalid condition is detected upon a ``flush()``.""" - pass + class InvalidRequestError(SQLAlchemyError): """SQLAlchemy was asked to do something it can't do, return @@ -64,28 +46,121 @@ class InvalidRequestError(SQLAlchemyError): This error generally corresponds to runtime state errors. """ - pass +class UnmappedColumnError(InvalidRequestError): + """A mapper was asked to return mapped information about a column + which it does not map""" class NoSuchTableError(InvalidRequestError): """SQLAlchemy was asked to load a table's definition from the database, but the table doesn't exist. """ - pass +class UnboundExecutionError(InvalidRequestError): + """SQL was attempted without a database connection to execute it on.""" class AssertionError(SQLAlchemyError): """Corresponds to internal state being detected in an invalid state.""" - pass class NoSuchColumnError(KeyError, SQLAlchemyError): """Raised by ``RowProxy`` when a nonexistent column is requested from a row.""" + +class NoReferencedTableError(InvalidRequestError): + """Raised by ``ForeignKey`` when the referred ``Table`` cannot be located.""" + +class DisconnectionError(SQLAlchemyError): + """Raised within ``Pool`` when a disconnect is detected on a raw DB-API connection. + + This error is consumed internally by a connection pool. It can be raised by + a ``PoolListener`` so that the host pool forces a disconnect. + """ - pass class DBAPIError(SQLAlchemyError): - """Something weird happened with a particular DBAPI version.""" + """Raised when the execution of a database operation fails. + + ``DBAPIError`` wraps exceptions raised by the DB-API underlying the + database operation. Driver-specific implementations of the standard + DB-API exception types are wrapped by matching sub-types of SQLAlchemy's + ``DBAPIError`` when possible. DB-API's ``Error`` type maps to + ``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note + that there is no guarantee that different DB-API implementations will + raise the same exception type for any given error condition. + + If the error-raising operation occured in the execution of a SQL + statement, that statement and its parameters will be available on + the exception object in the ``statement`` and ``params`` attributes. + + The wrapped exception object is available in the ``orig`` attribute. + Its type and properties are DB-API implementation specific. 
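At the call site the wrapping looks roughly like the following; the failing statement is contrived, and the exact DBAPIError subclass raised depends on the underlying DB-API::

    from sqlalchemy import create_engine, exceptions

    engine = create_engine('sqlite://')
    try:
        engine.execute("select * from no_such_table")
    except exceptions.DBAPIError, e:
        e.orig        # the wrapped DB-API exception instance
        e.statement   # 'select * from no_such_table'
        e.params      # the bind parameters in effect, if any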
+ """ - def __init__(self, message, orig): - SQLAlchemyError.__init__(self, "(%s) (%s) %s"% (message, orig.__class__.__name__, str(orig))) + def instance(cls, statement, params, orig, connection_invalidated=False): + # Don't ever wrap these, just return them directly as if + # DBAPIError didn't exist. + if isinstance(orig, (KeyboardInterrupt, SystemExit)): + return orig + + if orig is not None: + name, glob = orig.__class__.__name__, globals() + if name in glob and issubclass(glob[name], DBAPIError): + cls = glob[name] + + return cls(statement, params, orig, connection_invalidated) + instance = classmethod(instance) + + def __init__(self, statement, params, orig, connection_invalidated=False): + try: + text = str(orig) + except (KeyboardInterrupt, SystemExit): + raise + except Exception, e: + text = 'Error in str() of DB-API-generated exception: ' + str(e) + SQLAlchemyError.__init__( + self, "(%s) %s" % (orig.__class__.__name__, text)) + self.statement = statement + self.params = params self.orig = orig + self.connection_invalidated = connection_invalidated + + def __str__(self): + return ' '.join([SQLAlchemyError.__str__(self), + repr(self.statement), repr(self.params)]) + + +# As of 0.4, SQLError is now DBAPIError +SQLError = DBAPIError + +class InterfaceError(DBAPIError): + """Wraps a DB-API InterfaceError.""" + +class DatabaseError(DBAPIError): + """Wraps a DB-API DatabaseError.""" + +class DataError(DatabaseError): + """Wraps a DB-API DataError.""" + +class OperationalError(DatabaseError): + """Wraps a DB-API OperationalError.""" + +class IntegrityError(DatabaseError): + """Wraps a DB-API IntegrityError.""" + +class InternalError(DatabaseError): + """Wraps a DB-API InternalError.""" + +class ProgrammingError(DatabaseError): + """Wraps a DB-API ProgrammingError.""" + +class NotSupportedError(DatabaseError): + """Wraps a DB-API NotSupportedError.""" + +# Warnings +class SADeprecationWarning(DeprecationWarning): + """Issued once per usage of a deprecated API.""" + +class SAPendingDeprecationWarning(PendingDeprecationWarning): + """Issued once per usage of a deprecated API.""" + +class SAWarning(RuntimeWarning): + """Issued at runtime.""" diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py index fa32a5fc31..02f4b5b356 100644 --- a/lib/sqlalchemy/ext/activemapper.py +++ b/lib/sqlalchemy/ext/activemapper.py @@ -1,11 +1,9 @@ from sqlalchemy import ThreadLocalMetaData, util, Integer from sqlalchemy import Table, Column, ForeignKey -from sqlalchemy.orm import class_mapper, relation, create_session +from sqlalchemy.orm import class_mapper, relation, scoped_session +from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.sessioncontext import SessionContext -from sqlalchemy.ext.assignmapper import assign_mapper from sqlalchemy.orm import backref as create_backref -import sqlalchemy import inspect import sys @@ -14,20 +12,8 @@ import sys # the "proxy" to the database engine... 
this can be swapped out at runtime # metadata = ThreadLocalMetaData() - -try: - objectstore = sqlalchemy.objectstore -except AttributeError: - # thread local SessionContext - class Objectstore(object): - def __init__(self, *args, **kwargs): - self.context = SessionContext(*args, **kwargs) - def __getattr__(self, name): - return getattr(self.context.current, name) - session = property(lambda s:s.context.current) - - objectstore = Objectstore(create_session) - +Objectstore = scoped_session +objectstore = scoped_session(sessionmaker(autoflush=True, transactional=False)) # # declarative column declaration - this is so that we can infer the colname @@ -50,7 +36,7 @@ class column(object): # class relationship(object): def __init__(self, classname, colname=None, backref=None, private=False, - lazy=True, uselist=True, secondary=None, order_by=False): + lazy=True, uselist=True, secondary=None, order_by=False, viewonly=False): self.classname = classname self.colname = colname self.backref = backref @@ -59,6 +45,7 @@ class relationship(object): self.uselist = uselist self.secondary = secondary self.order_by = order_by + self.viewonly = viewonly def process(self, klass, propname, relations): relclass = ActiveMapperMeta.classes[self.classname] @@ -79,7 +66,8 @@ class relationship(object): private=self.private, lazy=self.lazy, uselist=self.uselist, - order_by=self.order_by) + order_by=self.order_by, + viewonly=self.viewonly) def create_backref(self, klass): if self.backref is None: @@ -88,7 +76,8 @@ class relationship(object): relclass = ActiveMapperMeta.classes[self.classname] if klass.__name__ == self.classname: - br_fkey = getattr(relclass.c, self.colname) + class_mapper(relclass).compile() + br_fkey = relclass.c[self.colname] else: br_fkey = None @@ -96,17 +85,14 @@ class relationship(object): class one_to_many(relationship): - def __init__(self, classname, colname=None, backref=None, private=False, - lazy=True, order_by=False): - relationship.__init__(self, classname, colname, backref, private, - lazy, uselist=True, order_by=order_by) - + def __init__(self, *args, **kwargs): + kwargs['uselist'] = True + relationship.__init__(self, *args, **kwargs) class one_to_one(relationship): - def __init__(self, classname, colname=None, backref=None, private=False, - lazy=True, order_by=False): - relationship.__init__(self, classname, colname, backref, private, - lazy, uselist=False, order_by=order_by) + def __init__(self, *args, **kwargs): + kwargs['uselist'] = False + relationship.__init__(self, *args, **kwargs) def create_backref(self, klass): if self.backref is None: @@ -156,9 +142,10 @@ def process_relationships(klass, was_deferred=False): # not able to find any of the related tables if not defer: for col in klass.columns: - if col.foreign_key is not None: + if col.foreign_keys: found = False - table_name = col.foreign_key._colspec.rsplit('.', 1)[0] + cn = col.foreign_keys[0]._colspec + table_name = cn[:cn.rindex('.')] for other_klass in ActiveMapperMeta.classes.values(): if other_klass.table.fullname.lower() == table_name.lower(): found = True @@ -277,10 +264,10 @@ class ActiveMapperMeta(type): # check for inheritence if hasattr(bases[0], "mapping"): cls._base_mapper= bases[0].mapper - assign_mapper(objectstore.context, cls, cls.table, + cls.mapper = objectstore.mapper(cls, cls.table, inherits=cls._base_mapper, version_id_col=version_id_col_object) else: - assign_mapper(objectstore.context, cls, cls.table, version_id_col=version_id_col_object) + cls.mapper = objectstore.mapper(cls, cls.table, 
version_id_col=version_id_col_object) cls.relations = relations ActiveMapperMeta.classes[clsname] = cls diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py index 2380417020..5a28fbe68b 100644 --- a/lib/sqlalchemy/ext/assignmapper.py +++ b/lib/sqlalchemy/ext/assignmapper.py @@ -1,7 +1,19 @@ from sqlalchemy import util, exceptions import types -from sqlalchemy.orm import mapper - +from sqlalchemy.orm import mapper, Query + +def _monkeypatch_query_method(name, ctx, class_): + def do(self, *args, **kwargs): + query = Query(class_, session=ctx.current) + util.warn_deprecated('Query methods on the class are deprecated; use %s.query.%s instead' % (class_.__name__, name)) + return getattr(query, name)(*args, **kwargs) + try: + do.__name__ = name + except: + pass + if not hasattr(class_, name): + setattr(class_, name, classmethod(do)) + def _monkeypatch_session_method(name, ctx, class_): def do(self, *args, **kwargs): session = ctx.current @@ -10,9 +22,9 @@ def _monkeypatch_session_method(name, ctx, class_): do.__name__ = name except: pass - if not hasattr(class_, name): + if not hasattr(class_, name): setattr(class_, name, do) - + def assign_mapper(ctx, class_, *args, **kwargs): extension = kwargs.pop('extension', None) if extension is not None: @@ -22,29 +34,39 @@ def assign_mapper(ctx, class_, *args, **kwargs): extension = ctx.mapper_extension validate = kwargs.pop('validate', False) - + if not isinstance(getattr(class_, '__init__'), types.MethodType): def __init__(self, **kwargs): for key, value in kwargs.items(): if validate: - if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False): - raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) + if not self.mapper.get_property(key, + resolve_synonyms=False, + raiseerr=False): + raise exceptions.ArgumentError( + "Invalid __init__ argument: '%s'" % key) setattr(self, key, value) class_.__init__ = __init__ - + class query(object): def __getattr__(self, key): return getattr(ctx.current.query(class_), key) def __call__(self): return ctx.current.query(class_) - if not hasattr(class_, 'query'): + if not hasattr(class_, 'query'): class_.query = query() - - for name in ['refresh', 'expire', 'delete', 'expunge', 'update']: + + for name in ('get', 'filter', 'filter_by', 'select', 'select_by', + 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', + 'get_by', 'join_to', 'join_via', 'count', 'count_by', + 'options', 'instances'): + _monkeypatch_query_method(name, ctx, class_) + for name in ('refresh', 'expire', 'delete', 'expunge', 'update'): _monkeypatch_session_method(name, ctx, class_) m = mapper(class_, extension=extension, *args, **kwargs) class_.mapper = m return m +assign_mapper = util.deprecated( + "assign_mapper is deprecated. Use scoped_session() instead.")(assign_mapper) diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 2dd8072228..d878f7b9b0 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -6,10 +6,13 @@ transparent proxied access to the endpoint of an association object. See the example ``examples/association/proxied_association.py``. 
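As a usage sketch (the class and relation names are hypothetical), the proxy exposes on ``User.keywords`` the ``keyword`` attribute of each association object held in the ``User.kw`` relation::

    from sqlalchemy.ext.associationproxy import association_proxy

    class User(object):
        # 'kw' is assumed to be a relation() to a Keyword association class
        keywords = association_proxy('kw', 'keyword')

    # user.keywords.append('chrome')  creates Keyword('chrome') in user.kw
    # list(user.keywords)             yields the plain keyword strings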
""" -import weakref, itertools -import sqlalchemy.exceptions as exceptions -import sqlalchemy.orm as orm -import sqlalchemy.util as util +import itertools +import weakref +from sqlalchemy import exceptions +from sqlalchemy import orm +from sqlalchemy import util +from sqlalchemy.orm import collections + def association_proxy(targetcollection, attr, **kw): """Convenience function for use in mapped classes. Implements a Python @@ -60,7 +63,7 @@ class AssociationProxy(object): on an object.""" def __init__(self, targetcollection, attr, creator=None, - proxy_factory=None, proxy_bulk_set=None): + getset_factory=None, proxy_factory=None, proxy_bulk_set=None): """Arguments are: targetcollection @@ -83,6 +86,17 @@ class AssociationProxy(object): If you want to construct instances differently, supply a 'creator' function that takes arguments as above and returns instances. + getset_factory + Optional. Proxied attribute access is automatically handled + by routines that get and set values based on the `attr` argument + for this proxy. + + If you would like to customize this behavior, you may supply a + `getset_factory` callable that produces a tuple of `getter` and + `setter` functions. The factory is called with two arguments, + the abstract type of the underlying collection and this proxy + instance. + proxy_factory Optional. The type of collection to emulate is determined by sniffing the target collection. If your collection type can't be @@ -94,10 +108,11 @@ class AssociationProxy(object): Optional, use with proxy_factory. See the _set() method for details. """ - + self.target_collection = targetcollection # backwards compat name... self.value_attr = attr self.creator = creator + self.getset_factory = getset_factory self.proxy_factory = proxy_factory self.proxy_bulk_set = proxy_bulk_set @@ -128,27 +143,39 @@ class AssociationProxy(object): "scope") return getattr(obj, target) return lazy_collection - + def __get__(self, obj, class_): + if self.owning_class is None: + self.owning_class = class_ and class_ or type(obj) if obj is None: - self.owning_class = class_ - return + return None elif self.scalar is None: self.scalar = self._target_is_scalar() + if self.scalar: + self._initialize_scalar_accessors() if self.scalar: - return getattr(getattr(obj, self.target_collection), self.value_attr) + return self._scalar_get(getattr(obj, self.target_collection)) else: try: - return getattr(obj, self.key) + # If the owning instance is reborn (orm session resurrect, + # etc.), refresh the proxy cache. 
+ creator_id, proxy = getattr(obj, self.key) + if id(obj) == creator_id: + return proxy except AttributeError: - proxy = self._new(self._lazy_collection(weakref.ref(obj))) - setattr(obj, self.key, proxy) - return proxy + pass + proxy = self._new(self._lazy_collection(weakref.ref(obj))) + setattr(obj, self.key, (id(obj), proxy)) + return proxy def __set__(self, obj, values): + if self.owning_class is None: + self.owning_class = type(obj) if self.scalar is None: self.scalar = self._target_is_scalar() + if self.scalar: + self._initialize_scalar_accessors() if self.scalar: creator = self.creator and self.creator or self.target_class @@ -156,15 +183,34 @@ class AssociationProxy(object): if target is None: setattr(obj, self.target_collection, creator(values)) else: - setattr(target, self.value_attr, values) + self._scalar_set(target, values) else: proxy = self.__get__(obj, None) - proxy.clear() - self._set(proxy, values) + if proxy is not values: + proxy.clear() + self._set(proxy, values) def __delete__(self, obj): + if self.owning_class is None: + self.owning_class = type(obj) delattr(obj, self.key) + def _initialize_scalar_accessors(self): + if self.getset_factory: + get, set = self.getset_factory(None, self) + else: + get, set = self._default_getset(None) + self._scalar_get, self._scalar_set = get, set + + def _default_getset(self, collection_class): + attr = self.value_attr + getter = util.attrgetter(attr) + if collection_class is dict: + setter = lambda o, k, v: setattr(o, attr, v) + else: + setter = lambda o, v: setattr(o, attr, v) + return getter, setter + def _new(self, lazy_collection): creator = self.creator and self.creator or self.target_class self.collection_class = util.duck_type_collection(lazy_collection()) @@ -172,15 +218,15 @@ class AssociationProxy(object): if self.proxy_factory: return self.proxy_factory(lazy_collection, creator, self.value_attr) - value_attr = self.value_attr - getter = lambda o: getattr(o, value_attr) - setter = lambda o, v: setattr(o, value_attr, v) - + if self.getset_factory: + getter, setter = self.getset_factory(self.collection_class, self) + else: + getter, setter = self._default_getset(self.collection_class) + if self.collection_class is list: return _AssociationList(lazy_collection, creator, getter, setter) elif self.collection_class is dict: - kv_setter = lambda o, k, v: setattr(o, value_attr, v) - return _AssociationDict(lazy_collection, creator, getter, kv_setter) + return _AssociationDict(lazy_collection, creator, getter, setter) elif self.collection_class is util.Set: return _AssociationSet(lazy_collection, creator, getter, setter) else: @@ -214,7 +260,7 @@ class _AssociationList(object): lazy_collection A callable returning a list-based collection of entities (usually an object attribute managed by a SQLAlchemy relation()) - + creator A function that creates new target entities. Given one parameter: value. 
The assertion is assumed: @@ -258,7 +304,7 @@ class _AssociationList(object): def __getitem__(self, index): return self._get(self.col[index]) - + def __setitem__(self, index, value): if not isinstance(index, slice): self._set(self.col[index], value) @@ -293,6 +339,7 @@ class _AssociationList(object): def __contains__(self, value): for member in self.col: + # testlib.pragma exempt:__eq__ if self._get(member) == value: return True return False @@ -367,16 +414,53 @@ class _AssociationList(object): def __ge__(self, other): return list(self) >= other def __cmp__(self, other): return cmp(list(self), other) + def __add__(self, iterable): + try: + other = list(iterable) + except TypeError: + return NotImplemented + return list(self) + other + + def __radd__(self, iterable): + try: + other = list(iterable) + except TypeError: + return NotImplemented + return other + list(self) + + def __mul__(self, n): + if not isinstance(n, int): + return NotImplemented + return list(self) * n + __rmul__ = __mul__ + + def __iadd__(self, iterable): + self.extend(iterable) + return self + + def __imul__(self, n): + # unlike a regular list *=, proxied __imul__ will generate unique + # backing objects for each copy. *= on proxied lists is a bit of + # a stretch anyhow, and this interpretation of the __imul__ contract + # is more plausibly useful than copying the backing objects. + if not isinstance(n, int): + return NotImplemented + if n == 0: + self.clear() + elif n > 1: + self.extend(list(self) * (n - 1)) + return self + def copy(self): return list(self) def __repr__(self): return repr(list(self)) - def hash(self): + def __hash__(self): raise TypeError("%s objects are unhashable" % type(self).__name__) -_NotProvided = object() +_NotProvided = util.symbol('_NotProvided') class _AssociationDict(object): """Generic proxying list which proxies dict operations to a another dict, converting association objects to and from a simplified value. @@ -387,7 +471,7 @@ class _AssociationDict(object): lazy_collection A callable returning a dict-based collection of entities (usually an object attribute managed by a SQLAlchemy relation()) - + creator A function that creates new target entities. Given two parameters: key and value. The assertion is assumed: @@ -429,7 +513,7 @@ class _AssociationDict(object): def __getitem__(self, key): return self._get(self.col[key]) - + def __setitem__(self, key, value): if key in self.col: self._set(self.col[key], key, value) @@ -440,6 +524,7 @@ class _AssociationDict(object): del self.col[key] def __contains__(self, key): + # testlib.pragma exempt:__hash__ return key in self.col has_key = __contains__ @@ -502,7 +587,7 @@ class _AssociationDict(object): def popitem(self): item = self.col.popitem() return (item[0], self._get(item[1])) - + def update(self, *a, **kw): if len(a) > 1: raise TypeError('update expected at most 1 arguments, got %i' % @@ -521,7 +606,7 @@ class _AssociationDict(object): def copy(self): return dict(self.items()) - def hash(self): + def __hash__(self): raise TypeError("%s objects are unhashable" % type(self).__name__) class _AssociationSet(object): @@ -534,7 +619,7 @@ class _AssociationSet(object): collection A callable returning a set-based collection of entities (usually an object attribute managed by a SQLAlchemy relation()) - + creator A function that creates new target entities. Given one parameter: value. 
The assertion is assumed: @@ -576,6 +661,7 @@ class _AssociationSet(object): def __contains__(self, value): for member in self.col: + # testlib.pragma exempt:__eq__ if self._get(member) == value: return True return False @@ -617,7 +703,12 @@ class _AssociationSet(object): for value in other: self.add(value) - __ior__ = update + def __ior__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + for value in other: + self.add(value) + return self def _set(self): return util.Set(iter(self)) @@ -636,7 +727,12 @@ class _AssociationSet(object): for value in other: self.discard(value) - __isub__ = difference_update + def __isub__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + for value in other: + self.discard(value) + return self def intersection(self, other): return util.Set(self).intersection(other) @@ -653,7 +749,18 @@ class _AssociationSet(object): for value in add: self.add(value) - __iand__ = intersection_update + def __iand__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), util.Set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self def symmetric_difference(self, other): return util.Set(self).symmetric_difference(other) @@ -670,14 +777,25 @@ class _AssociationSet(object): for value in add: self.add(value) - __ixor__ = symmetric_difference_update + def __ixor__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), util.Set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self def issubset(self, other): return util.Set(self).issubset(other) - + def issuperset(self, other): return util.Set(self).issuperset(other) - + def clear(self): self.col.clear() @@ -694,5 +812,5 @@ class _AssociationSet(object): def __repr__(self): return repr(util.Set(self)) - def hash(self): + def __hash__(self): raise TypeError("%s objects are unhashable" % type(self).__name__) diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py new file mode 100644 index 0000000000..d736736e95 --- /dev/null +++ b/lib/sqlalchemy/ext/declarative.py @@ -0,0 +1,355 @@ +"""A simple declarative layer for SQLAlchemy ORM. + +SQLAlchemy object-relational configuration involves the usage of Table, +mapper(), and class objects to define the three areas of configuration. +declarative moves these three types of configuration underneath the +individual mapped class. Regular SQLAlchemy schema and ORM constructs are +used in most cases:: + + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class SomeClass(Base): + __tablename__ = 'some_table' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + +Above, the ``declarative_base`` callable produces a new base class from +which all mapped classes inherit from. When the class definition is +completed, a new ``Table`` and ``mapper()`` have been generated, accessible +via the ``__table__`` and ``__mapper__`` attributes on the ``SomeClass`` +class. + +You may omit the names from the Column definitions. 
Declarative will fill +them in for you:: + + class SomeClass(Base): + __tablename__ = 'some_table' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + +Attributes may be added to the class after its construction, and they will +be added to the underlying ``Table`` and ``mapper()`` definitions as +appropriate:: + + SomeClass.data = Column('data', Unicode) + SomeClass.related = relation(RelatedInfo) + +Classes which are mapped explicitly using ``mapper()`` can interact freely +with declarative classes. + +The ``declarative_base`` base class contains a +``MetaData`` object where newly defined ``Table`` objects are collected. +This is accessed via the ``metadata`` class level accessor, so to +create tables we can say:: + + engine = create_engine('sqlite://') + Base.metadata.create_all(engine) + +The ``Engine`` created above may also be directly associated with the +declarative base class using the ``engine`` keyword argument, where it will +be associated with the underlying ``MetaData`` object and allow SQL +operations involving that metadata and its tables to make use of that +engine automatically:: + + Base = declarative_base(engine=create_engine('sqlite://')) + +Or, as ``MetaData`` allows, at any time using the ``bind`` attribute:: + + Base.metadata.bind = create_engine('sqlite://') + +The ``declarative_base`` can also receive a pre-created ``MetaData`` +object, which allows a declarative setup to be associated with an already existing traditional collection of ``Table`` objects:: + + mymetadata = MetaData() + Base = declarative_base(metadata=mymetadata) + +Relations to other classes are done in the usual way, with the added feature +that the class specified to ``relation()`` may be a string name. The "class +registry" associated with ``Base`` is used at mapper compilation time to +resolve the name into the actual class object, which is expected to have +been defined once the mapper configuration is used:: + + class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + name = Column(String(50)) + addresses = relation("Address", backref="user") + + class Address(Base): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + +Column constructs, since they are just that, are immediately usable, as +below where we define a primary join condition on the ``Address`` class +using them:: + + class Address(Base) + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + user = relation(User, primaryjoin=user_id==User.id) + +When an explicit join condition or other configuration which depends +on multiple classes cannot be defined immediately due to some classes +not yet being available, these can be defined after all classes have +been created. Attributes which are added to the class after +its creation are associated with the Table/mapping in the same +way as if they had been defined inline:: + + User.addresses = relation(Address, primaryjoin=Address.user_id==User.id) + +Synonyms are one area where ``declarative`` needs to slightly change the +usual SQLAlchemy configurational syntax. 
To define a getter/setter which +proxies to an underlying attribute, use ``synonym`` with the ``descriptor`` +argument:: + + class MyClass(Base): + __tablename__ = 'sometable' + + _attr = Column('attr', String) + + def _get_attr(self): + return self._some_attr + def _set_attr(self, attr) + self._some_attr = attr + attr = synonym('_attr', descriptor=property(_get_attr, _set_attr)) + +The above synonym is then usable as an instance attribute as well as a +class-level expression construct:: + + x = MyClass() + x.attr = "some value" + session.query(MyClass).filter(MyClass.attr == 'some other value').all() + +As an alternative to ``__tablename__``, a direct ``Table`` construct may be +used. The ``Column`` objects, which in this case require their names, +will be added to the mapping just like a regular mapping to a table:: + + class MyClass(Base): + __table__ = Table('my_table', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) + +This is the preferred approach when using reflected tables, as below:: + + class MyClass(Base): + __table__ = Table('my_table', Base.metadata, autoload=True) + +Mapper arguments are specified using the ``__mapper_args__`` class variable. +Note that the column objects declared on the class are immediately usable, +as in this joined-table inheritance example:: + + class Person(Base): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + discriminator = Column(String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity':'engineer'} + id = Column(Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column(String(50)) + +For single-table inheritance, the ``__tablename__`` and ``__table__`` class +variables are optional on a class when the class inherits from another +mapped class. + +As a convenience feature, the ``declarative_base()`` sets a default +constructor on classes which takes keyword arguments, and assigns them to +the named attributes:: + + e = Engineer(primary_language='python') + +Note that ``declarative`` has no integration built in with sessions, and is +only intended as an optional syntax for the regular usage of mappers and +Table objects. A typical application setup using ``scoped_session`` might +look like:: + + engine = create_engine('postgres://scott:tiger@localhost/test') + Session = scoped_session(sessionmaker(transactional=True, autoflush=False, bind=engine)) + Base = declarative_base() + +Mapped instances then make usage of ``Session`` in the usual way. 
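+
+As a rough illustrative sketch only (reusing the ``User`` class, the
+``engine``, ``Session`` and ``Base`` objects from the examples above, and
+the 0.4-style ``save()``/``commit()`` session methods), a typical round
+trip might look like::
+
+    Base.metadata.create_all(engine)
+
+    session = Session()
+    session.save(User(name='ed'))   # pending until flushed
+    session.commit()                # flushes, then commits the transaction
+
+    ed = session.query(User).filter_by(name='ed').first()
+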
+""" + +from sqlalchemy.schema import Table, Column, MetaData +from sqlalchemy.orm import synonym as _orm_synonym, mapper, comparable_property +from sqlalchemy.orm.interfaces import MapperProperty +from sqlalchemy.orm.properties import PropertyLoader, ColumnProperty +from sqlalchemy import util, exceptions +from sqlalchemy.sql import util as sql_util + + +__all__ = ['declarative_base', 'synonym_for', 'comparable_using', + 'declared_synonym'] + + +class DeclarativeMeta(type): + def __init__(cls, classname, bases, dict_): + if '_decl_class_registry' in cls.__dict__: + return type.__init__(cls, classname, bases, dict_) + + cls._decl_class_registry[classname] = cls + our_stuff = util.OrderedDict() + for k in dict_: + value = dict_[k] + if (isinstance(value, tuple) and len(value) == 1 and + isinstance(value[0], (Column, MapperProperty))): + util.warn("Ignoring declarative-like tuple value of attribute " + "%s: possibly a copy-and-paste error with a comma " + "left at the end of the line?" % k) + continue + if not isinstance(value, (Column, MapperProperty)): + continue + prop = _deferred_relation(cls, value) + our_stuff[k] = prop + + table = None + if '__table__' not in cls.__dict__: + if '__tablename__' in cls.__dict__: + tablename = cls.__tablename__ + autoload = cls.__dict__.get('__autoload__') + if autoload: + table_kw = {'autoload': True} + else: + table_kw = {} + cols = [] + for key, c in our_stuff.iteritems(): + if isinstance(c, ColumnProperty): + for col in c.columns: + if isinstance(col, Column) and col.table is None: + _undefer_column_name(key, col) + cols.append(col) + elif isinstance(c, Column): + _undefer_column_name(key, c) + cols.append(c) + cls.__table__ = table = Table(tablename, cls.metadata, + *cols, **table_kw) + else: + table = cls.__table__ + + mapper_args = getattr(cls, '__mapper_args__', {}) + if 'inherits' not in mapper_args: + inherits = cls.__mro__[1] + inherits = cls._decl_class_registry.get(inherits.__name__, None) + if inherits: + mapper_args['inherits'] = inherits + if not mapper_args.get('concrete', False) and table: + # figure out the inherit condition with relaxed rules about nonexistent tables, + # to allow for ForeignKeys to not-yet-defined tables (since we know for sure that our parent + # table is defined within the same MetaData) + mapper_args['inherit_condition'] = sql_util.join_condition(inherits.__table__, table, ignore_nonexistent_tables=True) + + if hasattr(cls, '__mapper_cls__'): + mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__) + else: + mapper_cls = mapper + cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args) + return type.__init__(cls, classname, bases, dict_) + + def __setattr__(cls, key, value): + if '__mapper__' in cls.__dict__: + if isinstance(value, Column): + _undefer_column_name(key, value) + cls.__table__.append_column(value) + cls.__mapper__.add_property(key, value) + elif isinstance(value, MapperProperty): + cls.__mapper__.add_property(key, _deferred_relation(cls, value)) + else: + type.__setattr__(cls, key, value) + else: + type.__setattr__(cls, key, value) + +def _deferred_relation(cls, prop): + if isinstance(prop, PropertyLoader) and isinstance(prop.argument, basestring): + arg = prop.argument + def return_cls(): + try: + return cls._decl_class_registry[arg] + except KeyError: + raise exceptions.InvalidRequestError("When compiling mapper %s, could not locate a declarative class named %r. Consider adding this property to the %r class after both dependent classes have been defined." 
% (prop.parent, arg, prop.parent.class_)) + prop.argument = return_cls + + return prop + +def declared_synonym(prop, name): + """Deprecated. Use synonym(name, descriptor=prop).""" + return _orm_synonym(name, descriptor=prop) +declared_synonym = util.deprecated(None, False)(declared_synonym) + +def synonym_for(name, map_column=False): + """Decorator, make a Python @property a query synonym for a column. + + A decorator version of [sqlalchemy.orm#synonym()]. The function being + decorated is the 'descriptor', otherwise passes its arguments through + to synonym():: + + @synonym_for('col') + @property + def prop(self): + return 'special sauce' + + The regular ``synonym()`` is also usable directly in a declarative + setting and may be convenient for read/write properties:: + + prop = synonym('col', descriptor=property(_read_prop, _write_prop)) + + """ + def decorate(fn): + return _orm_synonym(name, map_column=map_column, descriptor=fn) + return decorate + + +def comparable_using(comparator_factory): + """Decorator, allow a Python @property to be used in query criteria. + + A decorator front end to [sqlalchemy.orm#comparable_property()], passes + throgh the comparator_factory and the function being decorated:: + + @comparable_using(MyComparatorType) + @property + def prop(self): + return 'special sauce' + + The regular ``comparable_property()`` is also usable directly in a + declarative setting and may be convenient for read/write properties:: + + prop = comparable_property(MyComparatorType) + """ + def decorate(fn): + return comparable_property(comparator_factory, fn) + return decorate + +def declarative_base(engine=None, metadata=None, mapper=None): + lcl_metadata = metadata or MetaData() + if engine: + lcl_metadata.bind = engine + class Base(object): + __metaclass__ = DeclarativeMeta + metadata = lcl_metadata + if mapper: + __mapper_cls__ = mapper + _decl_class_registry = {} + def __init__(self, **kwargs): + for k in kwargs: + if not hasattr(type(self), k): + raise TypeError('%r is an invalid keyword argument for %s' % + (k, type(self).__name__)) + setattr(self, k, kwargs[k]) + return Base + +def _undefer_column_name(key, column): + if column.key is None: + column.key = key + if column.name is None: + column.name = key diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index e02990a263..e7464b0bdd 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -1,27 +1,80 @@ +"""A custom list that manages index/position information for its children. + +``orderinglist`` is a custom list collection implementation for mapped relations +that keeps an arbitrary "position" attribute on contained objects in sync with +each object's position in the Python list. + +The collection acts just like a normal Python ``list``, with the added +behavior that as you manipulate the list (via ``insert``, ``pop``, assignment, +deletion, what have you), each of the objects it contains is updated as needed +to reflect its position. 
This is very useful for managing ordered relations +which have a user-defined, serialized order:: + + from sqlalchemy.ext.orderinglist import ordering_list + + users = Table('users', metadata, + Column('id', Integer, primary_key=True)) + blurbs = Table('user_top_ten_list', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('users.id')), + Column('position', Integer), + Column('blurb', String(80))) + + class User(object): pass + class Blurb(object): + def __init__(self, blurb): + self.blurb = blurb + + mapper(User, users, properties={ + 'topten': relation(Blurb, collection_class=ordering_list('position'), + order_by=[blurbs.c.position]) + }) + mapper(Blurb, blurbs) + + u = User() + u.topten.append(Blurb('Number one!')) + u.topten.append(Blurb('Number two!')) + + # Like magic. + assert [blurb.position for blurb in u.topten] == [0, 1] + + # The objects will be renumbered automaticaly after any list-changing + # operation, for example an insert: + u.topten.insert(1, Blurb('I am the new Number Two.')) + + assert [blurb.position for blurb in u.topten] == [0, 1, 2] + assert u.topten[1].blurb == 'I am the new Number Two.' + assert u.topten[1].position == 1 + +Numbering and serialization are both highly configurable. See the docstrings +in this module and the main SQLAlchemy documentation for more information and +examples. + +The [sqlalchemy.ext.orderinglist#ordering_list] function is the ORM-compatible +constructor for OrderingList instances. """ -A custom list implementation for mapped relations that syncs position in a -Python list with a position attribute on the mapped objects. -""" + __all__ = [ 'ordering_list' ] def ordering_list(attr, count_from=None, **kw): - """ - Prepares an OrderingList factory for use as an argument to a - Mapper relation's 'collection_class' option. Arguments are: + """Prepares an OrderingList factory for use in mapper definitions. + + Returns an object suitable for use as an argument to a Mapper relation's + ``collection_class`` option. Arguments are: attr Name of the mapped attribute to use for storage and retrieval of ordering information count_from (optional) - Set up an integer-based ordering, starting at 'count_from'. For example, - ordering_list('pos', count_from=1) would create a 1-based list in SQL, - storing the value in the 'pos' column. Ignored if ordering_func is - supplied. + Set up an integer-based ordering, starting at ``count_from``. For + example, ``ordering_list('pos', count_from=1)`` would create a 1-based + list in SQL, storing the value in the 'pos' column. Ignored if + ``ordering_func`` is supplied. - Passes along any keyword arguments to OrderingList constructor. + Passes along any keyword arguments to ``OrderingList`` constructor. """ kw = _unsugar_count_from(count_from=count_from, **kw) @@ -50,8 +103,11 @@ def count_from_n_factory(start): return f def _unsugar_count_from(**kw): - """Keyword argument filter, prepares a simple ordering_func from - a 'count_from' argument, otherwise passes ordering_func on unchanged.""" + """Builds counting functions from keywrod arguments. + + Keyword argument filter, prepares a simple ``ordering_func`` from a + ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. + """ count_from = kw.pop('count_from', None) if kw.get('ordering_func', None) is None and count_from is not None: @@ -64,36 +120,45 @@ def _unsugar_count_from(**kw): return kw class OrderingList(list): + """A custom list that manages position information for its children. 
+ + See the module and __init__ documentation for more details. The + ``ordering_list`` function is used to configure ``OrderingList`` + collections in ``mapper`` relation definitions. + """ + def __init__(self, ordering_attr=None, ordering_func=None, reorder_on_append=False): - """ - A 'collection_class' list implementation that syncs position in a - Python list with a position attribute on the mapped objects. + """A custom list that manages position information for its children. + + ``OrderingList`` is a ``collection_class`` list implementation that + syncs position in a Python list with a position attribute on the + mapped objects. - This implementation counts on the list starting in the proper - order, so be SURE to put an order_by on your relation. - Arguments are: + This implementation relies on the list starting in the proper order, + so be **sure** to put an ``order_by`` on your relation. ordering_attr Name of the attribute that stores the object's order in the relation. ordering_func Optional. A function that maps the position in the Python list to a - value to store in the ordering_attr. Values returned are usually + value to store in the ``ordering_attr``. Values returned are usually (but need not be!) integers. - ordering_funcs are called with two positional parameters: index of - the element in the list, and the list itself. + An ``ordering_func`` is called with two positional parameters: the + index of the element in the list, and the list itself. - If omitted, list indexes are used for the attribute values. Two - basic pre-built numbering functions are provided: 'count_from_0' and - 'count_from_1'. For more exotic examples like stepped numbering, - alphabetical and Fibonacci numbering, see the unit tests. + If omitted, Python list indexes are used for the attribute values. + Two basic pre-built numbering functions are provided in this module: + ``count_from_0`` and ``count_from_1``. For more exotic examples + like stepped numbering, alphabetical and Fibonacci numbering, see + the unit tests. reorder_on_append - Default false. When appending an object with an existing (non-None) + Default False. When appending an object with an existing (non-None) ordering value, that value will be left untouched unless - reorder_on_append is true. This is an optimization to avoid a + ``reorder_on_append`` is true. This is an optimization to avoid a variety of dangerous unexpected database writes. SQLAlchemy will add instances to the list via append() when your @@ -102,13 +167,14 @@ class OrderingList(list): '2', '3', and '4'), reorder_on_append=True would immediately renumber the items to '1', '2', '3'. If you have multiple sessions making changes, any of whom happen to load this collection even in - passing, all of the sessions would try to 'clean up' the numbering + passing, all of the sessions would try to "clean up" the numbering in their commits, possibly causing all but one to fail with a concurrent modification error. Spooky action at a distance. Recommend leaving this with the default of False, and just call - ._reorder() if you're doing append() operations with previously - ordered instances or doing housekeeping after manual sql operations. + ``_reorder()`` if you're doing ``append()`` operations with + previously ordered instances or when doing some housekeeping after + manual sql operations. 
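+
+        As an illustrative sketch only: a hypothetical ``ordering_func``
+        named ``stepped`` below receives ``(index, collection)`` and returns
+        the value to store in the ``ordering_attr``.  Reusing the ``Blurb``
+        mapping from the module documentation, a step-of-ten numbering could
+        be configured as::
+
+            def stepped(index, collection):
+                # store 10, 20, 30, ... instead of raw list indexes
+                return (index + 1) * 10
+
+            mapper(User, users, properties={
+                'topten': relation(Blurb,
+                    collection_class=ordering_list('position',
+                                                   ordering_func=stepped),
+                    order_by=[blurbs.c.position])})
+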
""" self.ordering_attr = ordering_attr diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py index 1920b6f924..4462282542 100644 --- a/lib/sqlalchemy/ext/selectresults.py +++ b/lib/sqlalchemy/ext/selectresults.py @@ -15,7 +15,7 @@ class SelectResultsExt(orm.MapperExtension): def select(self, query, arg=None, **kwargs): if isinstance(arg, sql.FromClause) and arg.supports_execution(): - return orm.EXT_PASS + return orm.EXT_CONTINUE else: if arg is not None: query = query.filter(arg) diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py index fcbf29c3ff..5ac8acb405 100644 --- a/lib/sqlalchemy/ext/sessioncontext.py +++ b/lib/sqlalchemy/ext/sessioncontext.py @@ -1,34 +1,24 @@ -from sqlalchemy.util import ScopedRegistry -from sqlalchemy.orm import create_session, object_session, MapperExtension, EXT_PASS +from sqlalchemy.orm.scoping import ScopedSession, _ScopedExt +from sqlalchemy.util import warn_deprecated +from sqlalchemy.orm import create_session __all__ = ['SessionContext', 'SessionContextExt'] -class SessionContext(object): - """A simple wrapper for ``ScopedRegistry`` that provides a - `current` property which can be used to get, set, or remove the - session in the current scope. - By default this object provides thread-local scoping, which is the - default scope provided by sqlalchemy.util.ScopedRegistry. +class SessionContext(ScopedSession): + """Provides thread-local management of Sessions. Usage:: - engine = create_engine(...) - def session_factory(): - return Session(bind=engine) - context = SessionContext(session_factory) + context = SessionContext(sessionmaker(autoflush=True)) - s = context.current # get thread-local session - context.current = Session(bind=other_engine) # set current session - del context.current # discard the thread-local session (a new one will - # be created on the next call to context.current) """ def __init__(self, session_factory=None, scopefunc=None): + warn_deprecated("SessionContext is deprecated. Use scoped_session().") if session_factory is None: - session_factory = create_session - self.registry = ScopedRegistry(session_factory, scopefunc) - super(SessionContext, self).__init__() + session_factory=create_session + super(SessionContext, self).__init__(session_factory, scopefunc=scopefunc) def get_current(self): return self.registry() @@ -50,33 +40,11 @@ class SessionContext(object): return ext mapper_extension = property(_get_mapper_extension, - doc="""Get a mapper extension that implements `get_session` using this context.""") - - -class SessionContextExt(MapperExtension): - """A mapper extension that provides sessions to a mapper using ``SessionContext``.""" - - def __init__(self, context): - MapperExtension.__init__(self) - self.context = context - - def get_session(self): - return self.context.current - - def init_instance(self, mapper, class_, instance, args, kwargs): - session = kwargs.pop('_sa_session', self.context.current) - session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None)) - return EXT_PASS - - def init_failed(self, mapper, class_, instance, args, kwargs): - object_session(instance).expunge(instance) - return EXT_PASS - - def dispose_class(self, mapper, class_): - if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): - if class_.__init__._oldinit is not None: - class_.__init__ = class_.__init__._oldinit - else: - delattr(class_, '__init__') - - + doc="""Get a mapper extension that implements `get_session` using this context. 
Deprecated.""") + + +class SessionContextExt(_ScopedExt): + def __init__(self, *args, **kwargs): + warn_deprecated("SessionContextExt is deprecated. Use ScopedSession(enhance_classes=True)") + super(SessionContextExt, self).__init__(*args, **kwargs) + diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 756b5e1e73..bad9ba5a80 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -20,7 +20,7 @@ engine:: >>> from sqlalchemy.ext.sqlsoup import SqlSoup >>> db = SqlSoup('sqlite:///:memory:') -or, you can re-use an existing metadata:: +or, you can re-use an existing metadata or engine:: >>> db = SqlSoup(MetaData(e)) @@ -35,14 +35,14 @@ Loading objects Loading objects is as easy as this:: - >>> users = db.users.select() + >>> users = db.users.all() >>> users.sort() >>> users [MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0), MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1)] -Of course, letting the database do the sort is better (".c" is short for ".columns"):: +Of course, letting the database do the sort is better:: - >>> db.users.select(order_by=[db.users.c.name]) + >>> db.users.order_by(db.users.name).all() [MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1), MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)] Field access is intuitive:: @@ -55,14 +55,15 @@ WHERE clause. Let's also switch the order_by to DESC while we're at it:: >>> from sqlalchemy import or_, and_, desc - >>> where = or_(db.users.c.name=='Bhargan Basepair', db.users.c.email=='student@example.edu') - >>> db.users.select(where, order_by=[desc(db.users.c.name)]) + >>> where = or_(db.users.name=='Bhargan Basepair', db.users.email=='student@example.edu') + >>> db.users.filter(where).order_by(desc(db.users.name)).all() [MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0), MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1)] -You can also use the select...by methods if you're querying on a -single column. This allows using keyword arguments as column names:: +You can also use .first() (to retrieve only the first object from a query) or +.one() (like .first when you expect exactly one user -- it will raise an +exception if more were returned):: - >>> db.users.selectone_by(name='Bhargan Basepair') + >>> db.users.filter(db.users.name=='Bhargan Basepair').one() MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1) Since name is the primary key, this is equivalent to @@ -70,43 +71,22 @@ Since name is the primary key, this is equivalent to >>> db.users.get('Bhargan Basepair') MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1) +This is also equivalent to -Select variants ---------------- - -All the SQLAlchemy Query select variants are available. Here's a -quick summary of these methods: - -- ``get(PK)``: load a single object identified by its primary key - (either a scalar, or a tuple) - -- ``select(Clause, **kwargs)``: perform a select restricted by the - `Clause` argument; returns a list of objects. The most common clause - argument takes the form ``db.tablename.c.columname == value``. The - most common optional argument is `order_by`. 
- -- ``select_by(**params)``: select methods ending with ``_by`` allow - using bare column names (``columname=value``). This feels more - natural to most Python programmers; the downside is you can't - specify ``order_by`` or other select options. - -- ``selectfirst``, ``selectfirst_by``: returns only the first object - found; equivalent to ``select(...)[0]`` or ``select_by(...)[0]``, - except None is returned if no rows are selected. + >>> db.users.filter_by(name='Bhargan Basepair').one() + MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1) -- ``selectone``, ``selectone_by``: like ``selectfirst`` or - ``selectfirst_by``, but raises if less or more than one object is - selected. +filter_by is like filter, but takes kwargs instead of full clause expressions. +This makes it more concise for simple queries like this, but you can't do +complex queries like the or\_ above or non-equality based comparisons this way. -- ``count``, ``count_by``: returns an integer count of the rows - selected. +Full query documentation +------------------------ -See the SQLAlchemy documentation for details, `datamapping query`__ -for general info and examples, `sql construction`__ for details on -constructing ``WHERE`` clauses. +Get, filter, filter_by, order_by, limit, and the rest of the +query methods are explained in detail in the `SQLAlchemy documentation`__. -__ http://www.sqlalchemy.org/docs/datamapping.myt#datamapping_query -__ http://www.sqlalchemy.org/docs/sqlconstruction.myt +__ http://www.sqlalchemy.org/docs/04/ormtutorial.html#datamapping_querying Modifying objects @@ -125,12 +105,12 @@ multiple updates to a single object will be turned into a single To finish covering the basics, let's insert a new loan, then delete it:: - >>> book_id = db.books.selectfirst(db.books.c.title=='Regional Variation in Moss').id + >>> book_id = db.books.filter_by(title='Regional Variation in Moss').first().id >>> db.loans.insert(book_id=book_id, user_name=user.name) MappedLoans(book_id=2,user_name='Bhargan Basepair',loan_date=None) >>> db.flush() - >>> loan = db.loans.selectone_by(book_id=2, user_name='Bhargan Basepair') + >>> loan = db.loans.filter_by(book_id=2, user_name='Bhargan Basepair').one() >>> db.delete(loan) >>> db.flush() @@ -146,13 +126,13 @@ to the select methods. >>> db.loans.insert(book_id=book_id, user_name=user.name) MappedLoans(book_id=2,user_name='Bhargan Basepair',loan_date=None) >>> db.flush() - >>> db.loans.delete(db.loans.c.book_id==2) + >>> db.loans.delete(db.loans.book_id==2) You can similarly update multiple rows at once. This will change the book_id to 1 in all loans whose book_id is 2:: - >>> db.loans.update(db.loans.c.book_id==2, book_id=1) - >>> db.loans.select_by(db.loans.c.book_id==1) + >>> db.loans.update(db.loans.book_id==2, book_id=1) + >>> db.loans.filter_by(book_id=1).all() [MappedLoans(book_id=1,user_name='Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))] @@ -169,7 +149,7 @@ uses that as the join condition automatically. 
:: >>> join1 = db.join(db.users, db.loans, isouter=True) - >>> join1.select_by(name='Joe Student') + >>> join1.filter_by(name='Joe Student').all() [MappedJoin(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0,book_id=1,user_name='Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))] If you're unfortunate enough to be using MySQL with the default MyISAM @@ -177,7 +157,7 @@ storage engine, you'll have to specify the join condition manually, since MyISAM does not store foreign keys. Here's the same join again, with the join condition explicitly specified:: - >>> db.join(db.users, db.loans, db.users.c.name==db.loans.c.user_name, isouter=True) + >>> db.join(db.users, db.loans, db.users.name==db.loans.user_name, isouter=True) You can compose arbitrarily complex joins by combining Join objects @@ -185,11 +165,12 @@ with tables or other joins. Here we combine our first join with the books table:: >>> join2 = db.join(join1, db.books) - >>> join2.select() + >>> join2.all() [MappedJoin(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0,book_id=1,user_name='Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0),id=1,title='Mustards I Have Known',published_year='1989',authors='Jones')] If you join tables that have an identical column name, wrap your join -with `with_labels`, to disambiguate columns with their table name:: +with `with_labels`, to disambiguate columns with their table name +(.c is short for .columns):: >>> db.with_labels(join1).c.keys() [u'users_name', u'users_email', u'users_password', u'users_classname', u'users_admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date'] @@ -201,6 +182,28 @@ You can also join directly to a labeled object:: [u'name', u'email', u'password', u'classname', u'admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date'] +Relations +========= + +You can define relations on SqlSoup classes: + + >>> db.users.relate('loans', db.loans) + +These can then be used like a normal SA property: + + >>> db.users.get('Joe Student').loans + [MappedLoans(book_id=1,user_name='Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))] + + >>> db.users.filter(~db.users.loans.any()).all() + [MappedUsers(name='Bhargan Basepair',email='basepair+nospam@example.edu',password='basepair',classname=None,admin=1)] + + +relate can take any options that the relation function accepts in normal mapper definition: + + >>> del db._cache['users'] + >>> db.users.relate('loans', db.loans, order_by=db.loans.loan_date, cascade='all, delete-orphan') + + Advanced Use ============ @@ -237,7 +240,7 @@ PK in the database.) >>> s = select([b.c.published_year, func.count('*').label('n')], from_obj=[b], group_by=[b.c.published_year]) >>> s = s.alias('years_with_count') >>> years_with_count = db.map(s, primary_key=[s.c.published_year]) - >>> years_with_count.select_by(published_year='1989') + >>> years_with_count.filter_by(published_year='1989').all() [MappedBooks(published_year='1989',n=1)] Obviously if we just wanted to get a list of counts associated with @@ -259,7 +262,7 @@ Raw SQL SqlSoup works fine with SQLAlchemy's `text block support`__. -__ http://www.sqlalchemy.org/docs/documentation.myt#sql_textual +__ http://www.sqlalchemy.org/docs/04/sqlexpression.html#sql_text You can also access the SqlSoup's `engine` attribute to compose SQL directly. 
The engine's ``execute`` method corresponds to the one of a @@ -274,6 +277,16 @@ you would also see on a cursor:: You can also pass this engine object to other SQLAlchemy constructs. +Dynamic table names +------------------- + +You can load a table whose name is specified at runtime with the entity() method: + + >>> tablename = 'loans' + >>> db.entity(tablename) == db.loans + True + + Extra tests =========== @@ -281,7 +294,7 @@ Boring tests here. Nothing of real expository value. :: - >>> db.users.select(db.users.c.classname==None, order_by=[db.users.c.name]) + >>> db.users.filter_by(classname=None).order_by(db.users.name).all() [MappedUsers(name='Bhargan Basepair',email='basepair+nospam@example.edu',password='basepair',classname=None,admin=1), MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)] >>> db.nopk @@ -310,10 +323,11 @@ Boring tests here. Nothing of real expository value. """ from sqlalchemy import * +from sqlalchemy import schema, sql from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext from sqlalchemy.exceptions import * - +from sqlalchemy.sql import expression _testsql = """ CREATE TABLE books ( @@ -375,13 +389,13 @@ objectstore = Objectstore(create_session) class PKNotFoundError(SQLAlchemyError): pass -# metaclass is necessary to expose class methods with getattr, e.g. -# we want to pass db.users.select through to users._mapper.select def _ddl_error(cls): msg = 'SQLSoup can only modify mapped Tables (found: %s)' \ % cls._table.__class__.__name__ raise InvalidRequestError(msg) +# metaclass is necessary to expose class methods with getattr, e.g. +# we want to pass db.users.select through to users._mapper.select class SelectableClassType(type): def insert(cls, **kwargs): _ddl_error(cls) @@ -413,6 +427,9 @@ class TableClassType(SelectableClassType): def update(cls, whereclause=None, values=None, **kwargs): cls._table.update(whereclause, values).execute(**kwargs) + def relate(cls, propname, *args, **kwargs): + class_mapper(cls)._compile_property(propname, relation(*args, **kwargs)) + def _is_outer_join(selectable): if not isinstance(selectable, sql.Join): return False @@ -434,7 +451,7 @@ def _selectable_name(selectable): return x def class_for_table(selectable, **mapper_kwargs): - selectable = sql._selectable(selectable) + selectable = expression._selectable(selectable) mapname = 'Mapped' + _selectable_name(selectable) if isinstance(selectable, Table): klass = TableClassType(mapname, (object,), {}) @@ -518,13 +535,13 @@ class SqlSoup: def with_labels(self, item): # TODO give meaningful aliases - return self.map(sql._selectable(item).select(use_labels=True).alias('foo')) + return self.map(expression._selectable(item).select(use_labels=True).alias('foo')) def join(self, *args, **kwargs): j = join(*args, **kwargs) return self.map(j) - def __getattr__(self, attr): + def entity(self, attr): try: t = self._cache[attr] except KeyError: @@ -538,9 +555,14 @@ class SqlSoup: self._cache[attr] = t return t + def __getattr__(self, attr): + return self.entity(attr) + def __repr__(self): return 'SqlSoup(%r)' % self._metadata if __name__ == '__main__': + import logging + logging.basicConfig() import doctest doctest.testmod() diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py new file mode 100644 index 0000000000..eaad257698 --- /dev/null +++ b/lib/sqlalchemy/interfaces.py @@ -0,0 +1,89 @@ +# interfaces.py +# Copyright (C) 2007 Jason Kirtland jek@discorporate.us +# +# This module is part of 
SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Interfaces and abstract types.""" + + +class PoolListener(object): + """Hooks into the lifecycle of connections in a ``Pool``. + + All of the standard connection [sqlalchemy.pool#Pool] types can + accept event listeners for key connection lifecycle events: + creation, pool check-out and check-in. There are no events fired + when a connection closes. + + For any given DB-API connection, there will be one ``connect`` + event, `n` number of ``checkout`` events, and either `n` or `n - 1` + ``checkin`` events. (If a ``Connection`` is detached from its + pool via the ``detach()`` method, it won't be checked back in.) + + These are low-level events for low-level objects: raw Python + DB-API connections, without the conveniences of the SQLAlchemy + ``Connection`` wrapper, ``Dialect`` services or ``ClauseElement`` + execution. If you execute SQL through the connection, explicitly + closing all cursors and other resources is recommended. + + Events also receive a ``_ConnectionRecord``, a long-lived internal + ``Pool`` object that basically represents a "slot" in the + connection pool. ``_ConnectionRecord`` objects have one public + attribute of note: ``info``, a dictionary whose contents are + scoped to the lifetime of the DB-API connection managed by the + record. You can use this shared storage area however you like. + + There is no need to subclass ``PoolListener`` to handle events. + Any class that implements one or more of these methods can be used + as a pool listener. The ``Pool`` will inspect the methods + provided by a listener object and add the listener to one or more + internal event queues based on its capabilities. In terms of + efficiency and function call overhead, you're much better off only + providing implementations for the hooks you'll be using. + """ + + def connect(self, dbapi_con, con_record): + """Called once for each new DB-API connection or Pool's ``creator()``. + + dbapi_con + A newly connected raw DB-API connection (not a SQLAlchemy + ``Connection`` wrapper). + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ + + def checkout(self, dbapi_con, con_record, con_proxy): + """Called when a connection is retrieved from the Pool. + + dbapi_con + A raw DB-API connection + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + con_proxy + The ``_ConnectionFairy`` which manages the connection for the span of + the current checkout. + + If you raise an ``exceptions.DisconnectionError``, the current + connection will be disposed and a fresh connection retrieved. + Processing of all checkout listeners will abort and restart + using the new connection. + """ + + def checkin(self, dbapi_con, con_record): + """Called when a connection returns to the pool. + + Note that the connection may be closed, and may be None if the + connection has been invalidated. ``checkin`` will not be called + for detached connections. (They do not return to the pool.) 
+ + dbapi_con + A raw DB-API connection + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ diff --git a/lib/sqlalchemy/logging.py b/lib/sqlalchemy/logging.py index f02e3f746f..13872caa38 100644 --- a/lib/sqlalchemy/logging.py +++ b/lib/sqlalchemy/logging.py @@ -1,20 +1,21 @@ # logging.py - adapt python logging module to SQLAlchemy -# Copyright (C) 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Provides a few functions used by instances to turn on/off their -logging, including support for the usual "echo" parameter. +"""Logging control and utilities. -Control of logging for SA can be performed from the regular python -logging module. The regular dotted module namespace is used, starting -at 'sqlalchemy'. For class-level logging, the class name is appended, -and for instance-level logging, the hex id of the instance is -appended. +Provides a few functions used by instances to turn on/off their logging, +including support for the usual "echo" parameter. -The "echo" keyword parameter which is available on some SA objects -corresponds to an instance-level logger for that instance. +Control of logging for SA can be performed from the regular python logging +module. The regular dotted module namespace is used, starting at +'sqlalchemy'. For class-level logging, the class name is appended, and for +instance-level logging, the hex id of the instance is appended. + +The "echo" keyword parameter which is available on some SA objects corresponds +to an instance-level logger for that instance. E.g.:: @@ -23,16 +24,23 @@ E.g.:: is equivalent to:: import logging - logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine))).setLevel(logging.DEBUG) + logger = logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine))) + logger.setLevel(logging.DEBUG) """ -import sys +import sys, warnings +import sqlalchemy.exceptions as sa_exc # py2.5 absolute imports will fix.... logging = __import__('logging') +# moved to sqlalchemy.exceptions. this alias will be removed in 0.5. +SADeprecationWarning = sa_exc.SADeprecationWarning -logging.getLogger('sqlalchemy').setLevel(logging.WARN) +rootlogger = logging.getLogger('sqlalchemy') +if rootlogger.level == logging.NOTSET: + rootlogger.setLevel(logging.WARN) +warnings.filterwarnings("once", category=sa_exc.SADeprecationWarning) default_enabled = False def default_logging(name): @@ -41,19 +49,21 @@ def default_logging(name): default_enabled=True if not default_enabled: default_enabled = True - rootlogger = logging.getLogger('sqlalchemy') handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(name)s %(message)s')) + handler.setFormatter(logging.Formatter( + '%(asctime)s %(levelname)s %(name)s %(message)s')) rootlogger.addHandler(handler) def _get_instance_name(instance): - # since getLogger() does not have any way of removing logger objects from memory, - # instance logging displays the instance id as a modulus of 16 to prevent endless memory growth - # also speeds performance as logger initialization is apparently slow - return instance.__class__.__module__ + "." + instance.__class__.__name__ + ".0x.." 
+ hex(id(instance))[-2:] - -def instance_logger(instance): - return logging.getLogger(_get_instance_name(instance)) + # since getLogger() does not have any way of removing logger objects from + # memory, instance logging displays the instance id as a modulus of 16 to + # prevent endless memory growth also speeds performance as logger + # initialization is apparently slow + return "%s.%s.0x..%s" % (instance.__class__.__module__, + instance.__class__.__name__, + hex(id(instance))[-2:]) + return (instance.__class__.__module__ + "." + instance.__class__.__name__ + + ".0x.." + hex(id(instance))[-2:]) def class_logger(cls): return logging.getLogger(cls.__module__ + "." + cls.__name__) @@ -64,27 +74,40 @@ def is_debug_enabled(logger): def is_info_enabled(logger): return logger.isEnabledFor(logging.INFO) +def instance_logger(instance, echoflag=None): + if echoflag is not None: + l = logging.getLogger(_get_instance_name(instance)) + if echoflag == 'debug': + default_logging(_get_instance_name(instance)) + l.setLevel(logging.DEBUG) + elif echoflag is True: + default_logging(_get_instance_name(instance)) + l.setLevel(logging.INFO) + elif echoflag is False: + l.setLevel(logging.NOTSET) + else: + l = logging.getLogger(_get_instance_name(instance)) + instance._should_log_debug = l.isEnabledFor(logging.DEBUG) + instance._should_log_info = l.isEnabledFor(logging.INFO) + return l + class echo_property(object): - level_map={logging.DEBUG : "debug", logging.INFO:True} - - __doc__ = """when ``True``, enable log output for this element. - - This has the effect of setting the Python logging level for the - namespace of this element's class and object reference. A value - of boolean ``True`` indicates that the loglevel ``logging.INFO`` will be - set for the logger, whereas the string value ``debug`` will set the loglevel - to ``logging.DEBUG``. + __doc__ = """\ + When ``True``, enable log output for this element. + + + This has the effect of setting the Python logging level for the namespace + of this element's class and object reference. A value of boolean ``True`` + indicates that the loglevel ``logging.INFO`` will be set for the logger, + whereas the string value ``debug`` will set the loglevel to + ``logging.DEBUG``. 
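+
+    A minimal sketch (assuming ``engine`` is an ``Engine`` instance, which
+    exposes this property as ``engine.echo``)::
+
+        engine.echo = True       # log at logging.INFO for this instance
+        engine.echo = 'debug'    # log at logging.DEBUG for this instance
+        engine.echo = False      # reset this instance's logger to NOTSET
+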
""" - + def __get__(self, instance, owner): if instance is None: return self - level = logging.getLogger(_get_instance_name(instance)).getEffectiveLevel() - return echo_property.level_map.get(level, False) - - def __set__(self, instance, value): - if value: - default_logging(_get_instance_name(instance)) - logging.getLogger(_get_instance_name(instance)).setLevel(value == 'debug' and logging.DEBUG or logging.INFO) else: - logging.getLogger(_get_instance_name(instance)).setLevel(logging.NOTSET) + return instance._should_log_debug and 'debug' or (instance._should_log_info and True or False) + + def __set__(self, instance, value): + instance_logger(instance, echoflag=value) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 1982a94f78..2466a27637 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -1,39 +1,89 @@ # mapper/__init__.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """ -The mapper package provides object-relational functionality, building upon the schema and sql -packages and tying operations to class properties and constructors. +Functional constructs for ORM configuration. + +See the SQLAlchemy object relational tutorial and mapper configuration +documentation for an overview of how this module is used. """ -from sqlalchemy import exceptions -from sqlalchemy import util as sautil -from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, mapper_registry -from sqlalchemy.orm.interfaces import SynonymProperty, MapperExtension, EXT_PASS, ExtensionOption, PropComparator -from sqlalchemy.orm.properties import PropertyLoader, ColumnProperty, CompositeProperty, BackRef +from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, _mapper_registry +from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, EXT_STOP, EXT_PASS, ExtensionOption, PropComparator +from sqlalchemy.orm.properties import SynonymProperty, ComparableProperty, PropertyLoader, ColumnProperty, CompositeProperty, BackRef from sqlalchemy.orm import mapper as mapperlib -from sqlalchemy.orm import collections, strategies -from sqlalchemy.orm.query import Query -from sqlalchemy.orm.util import polymorphic_union -from sqlalchemy.orm.session import Session as create_session -from sqlalchemy.orm.session import object_session, attribute_manager - -__all__ = ['relation', 'column_property', 'composite', 'backref', 'eagerload', - 'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', - 'undefer_group', 'extension', 'mapper', 'clear_mappers', - 'compile_mappers', 'class_mapper', 'object_mapper', - 'MapperExtension', 'Query', 'polymorphic_union', 'create_session', - 'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS', - 'object_session', 'PropComparator' - ] +from sqlalchemy.orm import strategies +from sqlalchemy.orm.query import Query, aliased +from sqlalchemy.orm.util import polymorphic_union, create_row_adapter +from sqlalchemy.orm.session import Session as _Session +from sqlalchemy.orm.session import object_session, sessionmaker +from sqlalchemy.orm.scoping import ScopedSession + + +__all__ = [ 'relation', 'column_property', 'composite', 'backref', 'eagerload', + 'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer', + 'undefer', 'undefer_group', 'extension', 'mapper', 
'clear_mappers', + 'compile_mappers', 'class_mapper', 'object_mapper', 'sessionmaker', + 'scoped_session', 'dynamic_loader', 'MapperExtension', + 'polymorphic_union', 'comparable_property', + 'create_session', 'synonym', 'contains_alias', 'Query', 'aliased', + 'contains_eager', 'EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', + 'object_session', 'PropComparator' ] + + +def scoped_session(session_factory, scopefunc=None): + """Provides thread-local management of Sessions. + + This is a front-end function to the [sqlalchemy.orm.scoping#ScopedSession] + class. + + Usage:: + + Session = scoped_session(sessionmaker(autoflush=True)) + + To instantiate a Session object which is part of the scoped + context, instantiate normally:: + + session = Session() + + Most session methods are available as classmethods from + the scoped session:: + + Session.commit() + Session.close() + + To map classes so that new instances are saved in the current + Session automatically, as well as to provide session-aware + class attributes such as "query", use the `mapper` classmethod + from the scoped session:: + + mapper = Session.mapper + mapper(Class, table, ...) + + """ + + return ScopedSession(session_factory, scopefunc=scopefunc) + +def create_session(bind=None, **kwargs): + """create a new [sqlalchemy.orm.session#Session]. + + The session by default does not begin a transaction, and requires that + flush() be called explicitly in order to persist results to the database. + + It is recommended to use the [sqlalchemy.orm#sessionmaker()] function + instead of create_session(). + """ + kwargs.setdefault('autoflush', False) + kwargs.setdefault('transactional', False) + return _Session(bind=bind, **kwargs) def relation(argument, secondary=None, **kwargs): """Provide a relationship of a primary Mapper to a secondary Mapper. - This corresponds to a parent-child or associative table relationship. + This corresponds to a parent-child or associative table relationship. The constructed class is an instance of [sqlalchemy.orm.properties#PropertyLoader]. argument @@ -50,28 +100,28 @@ def relation(argument, secondary=None, **kwargs): \**kwargs follow: association - Deprecated; as of version 0.3.0 the association keyword is synonomous + Deprecated; as of version 0.3.0 the association keyword is synonymous with applying the "all, delete-orphan" cascade to a "one-to-many" relationship. SA can now automatically reconcile a "delete" and "insert" operation of two objects with the same "identity" in a flush() operation into a single "update" statement, which is the pattern that - "association" used to indicate. - + "association" used to indicate. + backref indicates the name of a property to be placed on the related mapper's class that will handle this relationship in the other direction, including synchronizing the object attributes on both sides of the relation. Can also point to a ``backref()`` construct for more - configurability. - + configurability. + cascade a string list of cascade rules which determines how persistence - operations should be "cascaded" from parent to child. - + operations should be "cascaded" from parent to child. + collection_class a class or function that returns a new list-holding object. will be - used in place of a plain list for storing elements. - + used in place of a plain list for storing elements. + foreign_keys a list of columns which are to be used as "foreign key" columns. 
this parameter should be used in conjunction with explicit @@ -87,46 +137,97 @@ def relation(argument, secondary=None, **kwargs): deprecated. use the ``foreign_keys`` argument for foreign key specification, or ``remote_side`` for "directional" logic. - lazy=True - specifies how the related items should be loaded. a value of True - indicates they should be loaded lazily when the property is first - accessed. A value of False indicates they should be loaded by joining - against the parent object query, so parent and child are loaded in one - round trip (i.e. eagerly). A value of None indicates the related items - are not loaded by the mapper in any case; the application will manually - insert items into the list in some other way. In all cases, items added - or removed to the parent object's collection (or scalar attribute) will - cause the appropriate updates and deletes upon flush(), i.e. this - option only affects load operations, not save operations. + join_depth=None + when non-``None``, an integer value indicating how many levels + deep eagerload joins should be constructed on a self-referring + or cyclical relationship. The number counts how many times + the same Mapper shall be present in the loading condition along + a particular join branch. When left at its default of ``None``, + eager loads will automatically stop chaining joins when they encounter + a mapper which is already higher up in the chain. + + lazy=(True|False|None|'dynamic') + specifies how the related items should be loaded. Values include: + + True - items should be loaded lazily when the property is first + accessed. + + False - items should be loaded "eagerly" in the same query as that + of the parent, using a JOIN or LEFT OUTER JOIN. + + None - no loading should occur at any time. This is to support + "write-only" attributes, or attributes which are populated in + some manner specific to the application. + + 'dynamic' - a ``DynaLoader`` will be attached, which returns a + ``Query`` object for all read operations. The dynamic- + collection supports only ``append()`` and ``remove()`` + for write operations; changes to the dynamic property + will not be visible until the data is flushed to the + database. order_by indicates the ordering that should be applied when loading these items. passive_deletes=False - Indicates if lazy-loaders should not be executed during the ``flush()`` - process, which normally occurs in order to locate all existing child - items when a parent item is to be deleted. Setting this flag to True is - appropriate when ``ON DELETE CASCADE`` rules have been set up on the - actual tables so that the database may handle cascading deletes - automatically. This strategy is useful particularly for handling the - deletion of objects that have very large (and/or deep) child-object - collections. + Indicates loading behavior during delete operations. + + A value of True indicates that unloaded child items should not be loaded + during a delete operation on the parent. Normally, when a parent + item is deleted, all child items are loaded so that they can either be + marked as deleted, or have their foreign key to the parent set to NULL. + Marking this flag as True usually implies an ON DELETE + rule is in place which will handle updating/deleting child rows on the + database side. + + Additionally, setting the flag to the string value 'all' will disable + the "nulling out" of the child foreign keys, when there is no delete or + delete-orphan cascade enabled. 
This is typically used when a triggering + or error raise scenario is in place on the database side. Note that + the foreign key attributes on in-session child objects will not be changed + after a flush occurs so this is a very special use-case setting. + + passive_updates=True + Indicates loading and INSERT/UPDATE/DELETE behavior when the source + of a foreign key value changes (i.e. an "on update" cascade), which + are typically the primary key columns of the source row. + + When True, it is assumed that ON UPDATE CASCADE is configured on the + foreign key in the database, and that the database will handle + propagation of an UPDATE from a source column to dependent rows. + Note that with databases which enforce referential integrity + (i.e. Postgres, MySQL with InnoDB tables), ON UPDATE CASCADE is + required for this operation. The relation() will update the value + of the attribute on related items which are locally present in the + session during a flush. + + When False, it is assumed that the database does not enforce + referential integrity and will not be issuing its own CASCADE + operation for an update. The relation() will issue the appropriate + UPDATE statements to the database in response to the change of a + referenced key, and items locally present in the session during a + flush will also be refreshed. + + This flag should probably be set to False if primary key changes are + expected and the database in use doesn't support CASCADE + (i.e. SQLite, MySQL MyISAM tables). post_update this indicates that the relationship should be handled by a second - UPDATE statement after an INSERT or before a DELETE. Currently, it also - will issue an UPDATE after the instance was UPDATEd as well, although - this technically should be improved. This flag is used to handle saving - bi-directional dependencies between two individual rows (i.e. each row - references the other), where it would otherwise be impossible to INSERT - or DELETE both rows fully since one row exists before the other. Use - this flag when a particular mapping arrangement will incur two rows - that are dependent on each other, such as a table that has a - one-to-many relationship to a set of child rows, and also has a column - that references a single child row within that list (i.e. both tables - contain a foreign key to each other). If a ``flush()`` operation returns - an error that a "cyclical dependency" was detected, this is a cue that - you might want to use ``post_update`` to "break" the cycle. + UPDATE statement after an INSERT or before a DELETE. Currently, it + also will issue an UPDATE after the instance was UPDATEd as well, + although this technically should be improved. This flag is used to + handle saving bi-directional dependencies between two individual + rows (i.e. each row references the other), where it would otherwise + be impossible to INSERT or DELETE both rows fully since one row + exists before the other. Use this flag when a particular mapping + arrangement will incur two rows that are dependent on each other, + such as a table that has a one-to-many relationship to a set of + child rows, and also has a column that references a single child row + within that list (i.e. both tables contain a foreign key to each + other). If a ``flush()`` operation returns an error that a "cyclical + dependency" was detected, this is a cue that you might want to use + ``post_update`` to "break" the cycle. 
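As a point of reference for the flags documented above, the following is a minimal one-to-many relation() sketch; the users/addresses tables, the User/Address classes and the ON DELETE rule are hypothetical and exist only to show how lazy, cascade, passive_deletes and backref fit together::

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey
    from sqlalchemy.orm import mapper, relation

    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)))
    addresses = Table('addresses', metadata,
        Column('id', Integer, primary_key=True),
        Column('user_id', Integer, ForeignKey('users.id', ondelete='CASCADE')),
        Column('email', String(50)))

    class User(object): pass
    class Address(object): pass

    mapper(Address, addresses)
    mapper(User, users, properties={
        # lazy=False: load children eagerly via LEFT OUTER JOIN;
        # cascade="all, delete-orphan": child rows follow the parent's lifecycle;
        # passive_deletes=True: rely on the ON DELETE CASCADE rule declared above;
        # backref='user': also place a scalar 'user' attribute on Address.
        'addresses': relation(Address, lazy=False,
                              cascade="all, delete-orphan",
                              passive_deletes=True,
                              backref='user')
    })

Later sketches in this section reuse these hypothetical tables and classes.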
primaryjoin a ClauseElement that will be used as the primary join of this child @@ -138,11 +239,11 @@ def relation(argument, secondary=None, **kwargs): private=False deprecated. setting ``private=True`` is the equivalent of setting ``cascade="all, delete-orphan"``, and indicates the lifecycle of child - objects should be contained within that of the parent. + objects should be contained within that of the parent. remote_side used for self-referential relationships, indicates the column or list - of columns that form the "remote side" of the relationship. + of columns that form the "remote side" of the relationship. secondaryjoin a ClauseElement that will be used as the join of an association table @@ -170,7 +271,28 @@ def relation(argument, secondary=None, **kwargs): return PropertyLoader(argument, secondary=secondary, **kwargs) -# return _relation_loader(argument, secondary=secondary, **kwargs) +def dynamic_loader(argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, + foreign_keys=None, backref=None, post_update=False, cascade=None, remote_side=None, enable_typechecks=True, + passive_deletes=False, order_by=None): + """construct a dynamically-loading mapper property. + + This property is similar to relation(), except read operations + return an active Query object, which reads from the database in all + cases. Items may be appended to the attribute via append(), or + removed via remove(); changes will be persisted + to the database during a flush(). However, no other list mutation + operations are available. + + A subset of arguments available to relation() are available here. + """ + + from sqlalchemy.orm.dynamic import DynaLoader + + return PropertyLoader(argument, secondary=secondary, primaryjoin=primaryjoin, + secondaryjoin=secondaryjoin, entity_name=entity_name, foreign_keys=foreign_keys, backref=backref, + post_update=post_update, cascade=cascade, remote_side=remote_side, enable_typechecks=enable_typechecks, + passive_deletes=passive_deletes, order_by=order_by, + strategy_class=DynaLoader) #def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs): @@ -183,57 +305,73 @@ def column_property(*args, **kwargs): the mapper's selectable; examples include SQL expressions, functions, and scalar SELECT queries. - Columns that arent present in the mapper's selectable won't be persisted + Columns that aren't present in the mapper's selectable won't be persisted by the mapper and are effectively "read-only" attributes. \*cols list of Column objects to be mapped. - + group a group name for this property when marked as deferred. - + deferred when True, the column property is "deferred", meaning that it does not load immediately, and is instead loaded when the - attribute is first accessed on an instance. See also + attribute is first accessed on an instance. See also [sqlalchemy.orm#deferred()]. """ - + return ColumnProperty(*args, **kwargs) def composite(class_, *cols, **kwargs): """Return a composite column-based property for use with a Mapper. - + This is very much like a column-based property except the given class - is used to construct values composed of one or more columns. The class must - implement a constructor with positional arguments matching the order of - columns given, as well as a __colset__() method which returns its attributes - in column order. - - class\_ - the "composite type" class. - - \*cols - list of Column objects to be mapped. - - group - a group name for this property when marked as deferred. 
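Reusing the hypothetical users/addresses tables and User/Address classes from the relation() sketch above, a hedged illustration of dynamic_loader() and column_property() follows; it is shown as an alternative mapping of User, and the filter_by() call only indicates the kind of Query the dynamic attribute returns::

    from sqlalchemy import select, func
    from sqlalchemy.orm import mapper, dynamic_loader, column_property

    mapper(User, users, properties={
        # reads return a Query; only append()/remove() mutate the collection
        'addresses': dynamic_loader(Address, backref='user'),

        # a read-only attribute computed from a correlated scalar subquery
        'address_count': column_property(
            select([func.count(addresses.c.id)],
                   addresses.c.user_id == users.c.id).label('address_count'))
    })

    # given a persistent User instance u:
    #     u.addresses.filter_by(email='ed@example.com').first()
    #     u.addresses.append(Address())    # persisted at the next flush()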
- - deferred - when True, the column property is "deferred", meaning that - it does not load immediately, and is instead loaded when the - attribute is first accessed on an instance. See also - [sqlalchemy.orm#deferred()]. - - comparator - an optional instance of [sqlalchemy.orm#PropComparator] which - provides SQL expression generation functions for this composite - type. + is used to represent "composite" values composed of one or more columns. + + The class must implement a constructor with positional arguments matching + the order of columns supplied here, as well as a __composite_values__() + method which returns values in the same order. + + A simple example is representing separate two columns in a table as a + single, first-class "Point" object:: + + class Point(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __composite_values__(self): + return (self.x, self.y) + + # and then in the mapping: + ... composite(Point, mytable.c.x, mytable.c.y) ... + + Arguments are: + + class\_ + The "composite type" class. + + \*cols + List of Column objects to be mapped. + + group + A group name for this property when marked as deferred. + + deferred + When True, the column property is "deferred", meaning that + it does not load immediately, and is instead loaded when the + attribute is first accessed on an instance. See also + [sqlalchemy.orm#deferred()]. + + comparator + An optional instance of [sqlalchemy.orm#PropComparator] which + provides SQL expression generation functions for this composite + type. """ - + return CompositeProperty(class_, *cols, **kwargs) - + def backref(name, **kwargs): """Create a BackRef object with explicit arguments, which are the same arguments one @@ -275,7 +413,7 @@ def mapper(class_, local_table=None, *args, **params): overwrite all data within object instances that already exist within the session, erasing any in-memory changes with whatever information was loaded from the database. Usage - of this flag is highly discouraged; as an alternative, + of this flag is highly discouraged; as an alternative, see the method `populate_existing()` on [sqlalchemy.orm.query#Query]. allow_column_override @@ -323,6 +461,11 @@ def mapper(class_, local_table=None, *args, **params): ``ClauseElement``) which will define how the two tables are joined; defaults to a natural join between the two tables. + inherit_foreign_keys + when inherit_condition is used and the condition contains no + ForeignKey columns, specify the "foreign" columns of the join + condition in this list. else leave as None. + order_by A single ``Column`` or list of ``Columns`` for which selection operations should use as the default ordering for @@ -337,12 +480,16 @@ def mapper(class_, local_table=None, *args, **params): polymorphic_on Used with mappers in an inheritance relationship, a ``Column`` which will identify the class/mapper combination to be used - with a particular row. requires the polymorphic_identity + with a particular row. Requires the ``polymorphic_identity`` value to be set for all mappers in the inheritance - hierarchy. + hierarchy. The column specified by ``polymorphic_on`` is + usually a column that resides directly within the base + mapper's mapped table; alternatively, it may be a column + that is only present within the portion + of the ``with_polymorphic`` argument. _polymorphic_map - Used internally to propigate the full map of polymorphic + Used internally to propagate the full map of polymorphic identifiers to surrogate mappers. 
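The Point example from the composite() docstring above can be rounded out into a full mapping; the vertices table and the Vertex class are assumptions made purely for illustration::

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.orm import mapper, composite

    metadata = MetaData()
    vertices = Table('vertices', metadata,
        Column('id', Integer, primary_key=True),
        Column('x1', Integer), Column('y1', Integer),
        Column('x2', Integer), Column('y2', Integer))

    class Point(object):
        def __init__(self, x, y):
            self.x, self.y = x, y
        def __composite_values__(self):
            return (self.x, self.y)
        def __eq__(self, other):
            return other is not None and other.x == self.x and other.y == self.y

    class Vertex(object): pass

    mapper(Vertex, vertices, properties={
        'start': composite(Point, vertices.c.x1, vertices.c.y1),
        'end': composite(Point, vertices.c.x2, vertices.c.y2)
    })

    # vertex.start and vertex.end are then Point instances; assigning a new
    # Point updates the underlying x/y columns at flush time.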
polymorphic_identity @@ -351,9 +498,9 @@ def mapper(class_, local_table=None, *args, **params): this mapper. polymorphic_fetch - specifies how subclasses mapped through joined-table - inheritance will be fetched. options are 'union', - 'select', and 'deferred'. if the select_table argument + specifies how subclasses mapped through joined-table + inheritance will be fetched. options are 'union', + 'select', and 'deferred'. if the 'with_polymorphic' argument is present, defaults to 'union', otherwise defaults to 'select'. @@ -366,17 +513,46 @@ def mapper(class_, local_table=None, *args, **params): each ``Column`` (although they can be overridden using this dictionary). + include_properties + An inclusive list of properties to map. Columns present in the + mapped table but not present in this list will not be automatically + converted into properties. + + exclude_properties + A list of properties not to map. Columns present in the + mapped table and present in this list will not be automatically + converted into properties. Note that neither this option nor + include_properties will allow an end-run around Python inheritance. + If mapped class ``B`` inherits from mapped class ``A``, no combination + of includes or excludes will allow ``B`` to have fewer properties + than its superclass, ``A``. + primary_key A list of ``Column`` objects which define the *primary key* to be used against this mapper's selectable unit. This is normally simply the primary key of the `local_table`, but can be overridden here. - + + with_polymorphic + A tuple in the form ``(, )`` indicating the + default style of "polymorphic" loading, that is, which tables + are queried at once. is any single or list of mappers + and/or classes indicating the inherited classes that should be + loaded at once. The special value ``'*'`` may be used to indicate + all descending classes should be loaded immediately. The second + tuple argument indicates a selectable that will be + used to query for multiple classes. Normally, it is left as + None, in which case this mapper will form an outer join from + the base mapper's table to that of all desired sub-mappers. + When specified, it provides the selectable to be used for + polymorphic loading. When with_polymorphic includes mappers + which load from a "concrete" inheriting table, the + argument is required, since it usually requires more complex + UNION queries. + select_table - A [sqlalchemy.schema#Table] or any [sqlalchemy.sql#Selectable] - which will be used to select instances of this mapper's class. - usually used to provide polymorphic loading among several - classes in an inheritance hierarchy. + Deprecated. Synonymous with + ``with_polymorphic=('*', )``. version_id_col A ``Column`` which must have an integer type that will be @@ -389,13 +565,85 @@ def mapper(class_, local_table=None, *args, **params): return Mapper(class_, local_table, *args, **params) -def synonym(name, proxy=False): - """Set up `name` as a synonym to another ``MapperProperty``. - - Used with the `properties` dictionary sent to ``mapper()``. +def synonym(name, map_column=False, descriptor=None, proxy=False): + """Set up `name` as a synonym to another mapped property. + + Used with the ``properties`` dictionary sent to [sqlalchemy.orm#mapper()]. 
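To make polymorphic_on, polymorphic_identity and with_polymorphic concrete, here is a hedged joined-table inheritance sketch; the employees/engineers tables and the Employee/Engineer classes are hypothetical, and whether the bare string '*' or the full ('*', <selectable>) tuple form is required may depend on the exact 0.4 release::

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey
    from sqlalchemy.orm import mapper

    metadata = MetaData()
    employees = Table('employees', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)),
        Column('type', String(20)))
    engineers = Table('engineers', metadata,
        Column('id', Integer, ForeignKey('employees.id'), primary_key=True),
        Column('primary_language', String(20)))

    class Employee(object): pass
    class Engineer(Employee): pass

    mapper(Employee, employees,
           polymorphic_on=employees.c.type,     # discriminator column
           polymorphic_identity='employee',
           with_polymorphic='*')                # join all subclass tables up front
    mapper(Engineer, engineers, inherits=Employee,
           polymorphic_identity='engineer')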
+ + Any existing attributes on the class which map the key name sent + to the ``properties`` dictionary will be used by the synonym to + provide instance-attribute behavior (that is, any Python property object, + provided by the ``property`` builtin or providing a ``__get__()``, + ``__set__()`` and ``__del__()`` method). If no name exists for the key, + the ``synonym()`` creates a default getter/setter object automatically + and applies it to the class. + + `name` refers to the name of the existing mapped property, which + can be any other ``MapperProperty`` including column-based + properties and relations. + + If `map_column` is ``True``, an additional ``ColumnProperty`` is created + on the mapper automatically, using the synonym's name as the keyname of + the property, and the keyname of this ``synonym()`` as the name of the + column to map. For example, if a table has a column named ``status``:: + + class MyClass(object): + def _get_status(self): + return self._status + def _set_status(self, value): + self._status = value + status = property(_get_status, _set_status) + + mapper(MyClass, sometable, properties={ + "status":synonym("_status", map_column=True) + }) + + The column named ``status`` will be mapped to the attribute named + ``_status``, and the ``status`` attribute on ``MyClass`` will be used to + proxy access to the column-based attribute. + + The `proxy` keyword argument is deprecated and currently does nothing; + synonyms now always establish an attribute getter/setter function if one + is not already available. """ - return SynonymProperty(name, proxy=proxy) + return SynonymProperty(name, map_column=map_column, descriptor=descriptor) + +def comparable_property(comparator_factory, descriptor=None): + """Provide query semantics for an unmanaged attribute. + + Allows a regular Python @property (descriptor) to be used in Queries and + SQL constructs like a managed attribute. comparable_property wraps a + descriptor with a proxy that directs operator overrides such as == + (__eq__) to the supplied comparator but proxies everything else through + to the original descriptor:: + + class MyClass(object): + @property + def myprop(self): + return 'foo' + + class MyComparator(sqlalchemy.orm.interfaces.PropComparator): + def __eq__(self, other): + .... + + mapper(MyClass, mytable, properties=dict( + 'myprop': comparable_property(MyComparator))) + + Used with the ``properties`` dictionary sent to [sqlalchemy.orm#mapper()]. + + comparator_factory + A PropComparator subclass or factory that defines operator behavior + for this property. + + descriptor + Optional when used in a ``properties={}`` declaration. The Python + descriptor or property to layer comparison behavior on top of. + + The like-named descriptor will be automatically retreived from the + mapped class if left blank in a ``properties`` declaration. + """ + return ComparableProperty(comparator_factory, descriptor) def compile_mappers(): """Compile all mappers that have been defined. @@ -403,28 +651,22 @@ def compile_mappers(): This is equivalent to calling ``compile()`` on any individual mapper. """ - if not len(mapper_registry): - return - mapper_registry.values()[0].compile() + for m in list(_mapper_registry): + m.compile() def clear_mappers(): """Remove all mappers that have been created thus far. - When new mappers are created, they will be assigned to their - classes as their primary mapper. + The mapped classes will return to their initial "unmapped" + state and can be re-mapped with new mappers. 
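Note that the ``properties`` example in the comparable_property() docstring above mixes dict() call syntax with a dict-literal key ('myprop': ... inside dict(...) is not valid Python). A corrected, self-contained sketch might look like the following, where mytable and its 'word' column are assumptions::

    from sqlalchemy import MetaData, Table, Column, Integer, String
    from sqlalchemy.orm import mapper, comparable_property
    from sqlalchemy.orm.interfaces import PropComparator

    metadata = MetaData()
    mytable = Table('mytable', metadata,
        Column('id', Integer, primary_key=True),
        Column('word', String(50)))

    class MyClass(object):
        @property
        def myprop(self):
            return 'foo'

    class MyComparator(PropComparator):
        def __eq__(self, other):
            # hypothetical: compare the 'word' column against the operand
            return mytable.c.word == other

    mapper(MyClass, mytable, properties={
        'myprop': comparable_property(MyComparator)
    })

    # MyClass.myprop == 'bar' now renders roughly "mytable.word = :word_1",
    # while instance access still returns the plain property value 'foo'.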
""" - mapperlib._COMPILE_MUTEX.acquire() try: - for mapper in mapper_registry.values(): + for mapper in list(_mapper_registry): mapper.dispose() - mapper_registry.clear() - # TODO: either dont use ArgSingleton, or - # find a way to clear only ClassKey instances from it - sautil.ArgSingleton.instances.clear() finally: mapperlib._COMPILE_MUTEX.release() - + def extension(ext): """Return a ``MapperOption`` that will insert the given ``MapperExtension`` to the beginning of the list of extensions @@ -435,43 +677,38 @@ def extension(ext): return ExtensionOption(ext) -def eagerload(name): - """Return a ``MapperOption`` that will convert the property of the - given name into an eager load. +def eagerload(name, mapper=None): + """Return a ``MapperOption`` that will convert the property of the given name into an eager load. Used with ``query.options()``. """ - return strategies.EagerLazyOption(name, lazy=False) + return strategies.EagerLazyOption(name, lazy=False, mapper=mapper) + +def eagerload_all(name, mapper=None): + """Return a ``MapperOption`` that will convert all properties along the given dot-separated path into an eager load. + + For example, this:: -def eagerload_all(name): - """Return a ``MapperOption`` that will convert all - properties along the given dot-separated path into an - eager load. - - e.g:: query.options(eagerload_all('orders.items.keywords'))... - + will set all of 'orders', 'orders.items', and 'orders.items.keywords' to load in one eager load. Used with ``query.options()``. """ - return strategies.EagerLazyOption(name, lazy=False, chained=True) + return strategies.EagerLazyOption(name, lazy=False, chained=True, mapper=mapper) -def lazyload(name): +def lazyload(name, mapper=None): """Return a ``MapperOption`` that will convert the property of the given name into a lazy load. Used with ``query.options()``. """ - return strategies.EagerLazyOption(name, lazy=True) + return strategies.EagerLazyOption(name, lazy=True, mapper=mapper) -def fetchmode(name, type): - return strategies.FetchModeOption(name, type) - def noload(name): """Return a ``MapperOption`` that will convert the property of the given name into a non-load. @@ -493,21 +730,14 @@ def contains_alias(alias): def __init__(self, alias): self.alias = alias if isinstance(self.alias, basestring): - self.selectable = None + self.translator = None else: - self.selectable = alias - def get_selectable(self, mapper): - if self.selectable is None: - self.selectable = mapper.mapped_table.alias(self.alias) - return self.selectable + self.translator = create_row_adapter(alias) + def translate_row(self, mapper, context, row): - newrow = sautil.DictDecorator(row) - selectable = self.get_selectable(mapper) - for c in mapper.mapped_table.c: - c2 = selectable.corresponding_column(c, keys_ok=True, raiseerr=False) - if c2 and row.has_key(c2): - newrow[c] = row[c2] - return newrow + if not self.translator: + self.translator = create_row_adapter(mapper.mapped_table.alias(self.alias)) + return self.translator(row) return ExtensionOption(AliasedRow(alias)) @@ -517,7 +747,7 @@ def contains_eager(key, alias=None, decorator=None): Used when feeding SQL result sets directly into ``query.instances()``. Also bundles an ``EagerLazyOption`` to - turn on eager loading in case it isnt already. + turn on eager loading in case it isn't already. `alias` is the string name of an alias, **or** an ``sql.Alias`` object, which represents the aliased columns in the query. 
This @@ -548,10 +778,9 @@ def undefer(name): return strategies.DeferredOption(name, defer=False) def undefer_group(name): - """Return a ``MapperOption`` that will convert the given + """Return a ``MapperOption`` that will convert the given group of deferred column properties into a non-deferred (regular column) load. Used with ``query.options()``. """ return strategies.UndeferGroupOption(name) - diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 47ff260853..fb0621a70f 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1,100 +1,203 @@ # attributes.py - manages object attributes -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import weakref - +import operator, weakref +from itertools import chain +import UserDict from sqlalchemy import util -from sqlalchemy.orm import util as orm_util, interfaces, collections -from sqlalchemy.orm.mapper import class_mapper -from sqlalchemy import logging, exceptions - +from sqlalchemy.orm import interfaces, collections +from sqlalchemy.orm.util import identity_equal +from sqlalchemy import exceptions -PASSIVE_NORESULT = object() -ATTR_WAS_SET = object() +PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT') +ATTR_WAS_SET = util.symbol('ATTR_WAS_SET') +NO_VALUE = util.symbol('NO_VALUE') +NEVER_SET = util.symbol('NEVER_SET') class InstrumentedAttribute(interfaces.PropComparator): - """attribute access for instrumented classes.""" - - def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs): + """public-facing instrumented attribute, placed in the + class dictionary. + + """ + + def __init__(self, impl, comparator=None): """Construct an InstrumentedAttribute. - - class_ - the class to be instrumented. - - manager - AttributeManager managing this class - - key - string name of the attribute - - callable_ - optional function which generates a callable based on a parent - instance, which produces the "default" values for a scalar or - collection attribute when it's first accessed, if not present already. - - trackparent - if True, attempt to track if an instance has a parent attached to it - via this attribute - - extension - an AttributeExtension object which will receive - set/delete/append/remove/etc. 
events - - compare_function - a function that compares two values which are normally assignable to this - attribute - - mutable_scalars - if True, the values which are normally assignable to this attribute can mutate, - and need to be compared against a copy of their original contents in order to - detect changes on the parent instance - - comparator - a sql.Comparator to which class-level compare/math events will be sent - + comparator + a sql.Comparator to which class-level compare/math events will be sent """ - - self.class_ = class_ - self.manager = manager - self.key = key - self.callable_ = callable_ - self.trackparent = trackparent - self.mutable_scalars = mutable_scalars + + self.impl = impl self.comparator = comparator - self.copy = None - if compare_function is None: - self.is_equal = lambda x,y: x == y - else: - self.is_equal = compare_function - self.extensions = util.to_list(extension or []) - def __set__(self, obj, value): - self.set(obj, value, None) + def __set__(self, instance, value): + self.impl.set(instance._state, value, None) - def __delete__(self, obj): - self.delete(None, obj) + def __delete__(self, instance): + self.impl.delete(instance._state) - def __get__(self, obj, owner): - if obj is None: + def __get__(self, instance, owner): + if instance is None: return self - return self.get(obj) + return self.impl.get(instance._state) + + def get_history(self, instance, **kwargs): + return self.impl.get_history(instance._state, **kwargs) def clause_element(self): return self.comparator.clause_element() def expression_element(self): return self.comparator.expression_element() - - def operate(self, op, other, **kwargs): - return op(self.comparator, other, **kwargs) + + def operate(self, op, *other, **kwargs): + return op(self.comparator, *other, **kwargs) def reverse_operate(self, op, other, **kwargs): return op(other, self.comparator, **kwargs) - - def hasparent(self, item, optimistic=False): + + def hasparent(self, instance, optimistic=False): + return self.impl.hasparent(instance._state, optimistic=optimistic) + + def _property(self): + from sqlalchemy.orm.mapper import class_mapper + return class_mapper(self.impl.class_).get_property(self.impl.key) + property = property(_property, doc="the MapperProperty object associated with this attribute") + +class ProxiedAttribute(InstrumentedAttribute): + """Adds InstrumentedAttribute class-level behavior to a regular descriptor. + + Obsoleted by proxied_attribute_factory. + """ + + class ProxyImpl(object): + accepts_scalar_loader = False + + def __init__(self, key): + self.key = key + + def __init__(self, key, user_prop, comparator=None): + self.user_prop = user_prop + self._comparator = comparator + self.key = key + self.impl = ProxiedAttribute.ProxyImpl(key) + + def comparator(self): + if callable(self._comparator): + self._comparator = self._comparator() + return self._comparator + comparator = property(comparator) + + def __get__(self, instance, owner): + if instance is None: + self.user_prop.__get__(instance, owner) + return self + return self.user_prop.__get__(instance, owner) + + def __set__(self, instance, value): + return self.user_prop.__set__(instance, value) + + def __delete__(self, instance): + return self.user_prop.__delete__(instance) + +def proxied_attribute_factory(descriptor): + """Create an InstrumentedAttribute / user descriptor hybrid. + + Returns a new InstrumentedAttribute type that delegates descriptor + behavior and getattr() to the given descriptor. 
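The class-level versus instance-level split described above can be seen directly; this sketch assumes the hypothetical User mapping from earlier in this section and uses Python 2 print statements, so it is illustrative only::

    from sqlalchemy.orm import compile_mappers
    compile_mappers()   # make sure mapped classes are fully instrumented

    # class-level access returns the InstrumentedAttribute itself; its
    # comparator turns Python operators into SQL expressions.
    print User.name == 'ed'        # renders roughly: users.name = :name_1

    # instance-level access is routed through the attribute's impl, which
    # records change history on the instance's InstanceState (obj._state).
    u = User()
    u.name = 'ed'                  # InstrumentedAttribute.__set__ -> impl.set()
    print u.name                   # InstrumentedAttribute.__get__ -> impl.get()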
+ """ + + class ProxyImpl(object): + accepts_scalar_loader = False + def __init__(self, key): + self.key = key + + class Proxy(InstrumentedAttribute): + """A combination of InsturmentedAttribute and a regular descriptor.""" + + def __init__(self, key, descriptor, comparator): + self.key = key + # maintain ProxiedAttribute.user_prop compatability. + self.descriptor = self.user_prop = descriptor + self._comparator = comparator + self.impl = ProxyImpl(key) + + def comparator(self): + if callable(self._comparator): + self._comparator = self._comparator() + return self._comparator + comparator = property(comparator) + + def __get__(self, instance, owner): + """Delegate __get__ to the original descriptor.""" + if instance is None: + descriptor.__get__(instance, owner) + return self + return descriptor.__get__(instance, owner) + + def __set__(self, instance, value): + """Delegate __set__ to the original descriptor.""" + return descriptor.__set__(instance, value) + + def __delete__(self, instance): + """Delegate __delete__ to the original descriptor.""" + return descriptor.__delete__(instance) + + def __getattr__(self, attribute): + """Delegate __getattr__ to the original descriptor.""" + return getattr(descriptor, attribute) + Proxy.__name__ = type(descriptor).__name__ + 'Proxy' + + util.monkeypatch_proxied_specials(Proxy, type(descriptor), + name='descriptor', + from_instance=descriptor) + return Proxy + +class AttributeImpl(object): + """internal implementation for instrumented attributes.""" + + def __init__(self, class_, key, callable_, trackparent=False, extension=None, compare_function=None, **kwargs): + """Construct an AttributeImpl. + + class_ + the class to be instrumented. + + key + string name of the attribute + + callable_ + optional function which generates a callable based on a parent + instance, which produces the "default" values for a scalar or + collection attribute when it's first accessed, if not present + already. + + trackparent + if True, attempt to track if an instance has a parent attached + to it via this attribute. + + extension + an AttributeExtension object which will receive + set/delete/append/remove/etc. events. + + compare_function + a function that compares two values which are normally + assignable to this attribute. + + """ + + self.class_ = class_ + self.key = key + self.callable_ = callable_ + self.trackparent = trackparent + if compare_function is None: + self.is_equal = operator.eq + else: + self.is_equal = compare_function + self.extensions = util.to_list(extension or []) + + def hasparent(self, state, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. The `optimistic` flag determines what the default return value @@ -109,32 +212,17 @@ class InstrumentedAttribute(interfaces.PropComparator): will also not have a `hasparent` flag. """ - return item._state.get(('hasparent', id(self)), optimistic) + return state.parents.get(id(self), optimistic) - def sethasparent(self, item, value): + def sethasparent(self, state, value): """Set a boolean flag on the given item corresponding to whether or not it is attached to a parent object via the attribute represented by this ``InstrumentedAttribute``. """ - item._state[('hasparent', id(self))] = value - - def get_history(self, obj, passive=False): - """Return a new ``AttributeHistory`` object for the given object/this attribute's key. 
- - If `passive` is True, then don't execute any callables; if the - attribute's value can only be achieved via executing a - callable, then return None. - """ - - # get the current state. this may trigger a lazy load if - # passive is False. - current = self.get(obj, passive=passive) - if current is PASSIVE_NORESULT: - return None - return AttributeHistory(self, obj, current, passive=passive) + state.parents[id(self)] = value - def set_callable(self, obj, callable_): + def set_callable(self, state, callable_): """Set a callable function for this attribute on the given object. This callable will be executed when the attribute is next @@ -150,53 +238,28 @@ class InstrumentedAttribute(interfaces.PropComparator): """ if callable_ is None: - self.initialize(obj) + self.initialize(state) else: - obj._state[('callable', self)] = callable_ + state.callables[self.key] = callable_ + + def get_history(self, state, passive=False): + raise NotImplementedError() - def _get_callable(self, obj): - if ('callable', self) in obj._state: - return obj._state[('callable', self)] + def _get_callable(self, state): + if self.key in state.callables: + return state.callables[self.key] elif self.callable_ is not None: - return self.callable_(obj) + return self.callable_(state.obj()) else: return None - def reset(self, obj): - """Remove any per-instance callable functions corresponding to - this ``InstrumentedAttribute``'s attribute from the given - object, and remove this ``InstrumentedAttribute``'s attribute - from the given object's dictionary. - """ - - try: - del obj._state[('callable', self)] - except KeyError: - pass - self.clear(obj) - - def clear(self, obj): - """Remove this ``InstrumentedAttribute``'s attribute from the given object's dictionary. - - Subsequent calls to ``getattr(obj, key)`` will raise an - ``AttributeError`` by default. - """ - - try: - del obj.__dict__[self.key] - except KeyError: - pass - - def check_mutable_modified(self, obj): - return False - - def initialize(self, obj): + def initialize(self, state): """Initialize this attribute on the given object instance with an empty value.""" - obj.__dict__[self.key] = None + state.dict[self.key] = None return None - def get(self, obj, passive=False): + def get(self, state, passive=False): """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and @@ -205,131 +268,154 @@ class InstrumentedAttribute(interfaces.PropComparator): """ try: - return obj.__dict__[self.key] + return state.dict[self.key] except KeyError: - state = obj._state - # if an instance-wide "trigger" was set, call that - # and start again - if 'trigger' in state: - trig = state['trigger'] - del state['trigger'] - trig() - return self.get(obj, passive=passive) - - callable_ = self._get_callable(obj) - if callable_ is not None: - if passive: - return PASSIVE_NORESULT - self.logger.debug("Executing lazy callable on %s.%s" % - (orm_util.instance_str(obj), self.key)) - value = callable_() - if value is not ATTR_WAS_SET: - return self.set_committed_value(obj, value) - else: - return obj.__dict__[self.key] - else: - # Return a new, empty value - return self.initialize(obj) + # if no history, check for lazy callables, etc. 
+ if self.key not in state.committed_state: + callable_ = self._get_callable(state) + if callable_ is not None: + if passive: + return PASSIVE_NORESULT + value = callable_() + if value is not ATTR_WAS_SET: + return self.set_committed_value(state, value) + else: + if self.key not in state.dict: + return self.get(state, passive=passive) + return state.dict[self.key] - def append(self, obj, value, initiator): - self.set(obj, value, initiator) + # Return a new, empty value + return self.initialize(state) - def remove(self, obj, value, initiator): - self.set(obj, None, initiator) + def append(self, state, value, initiator, passive=False): + self.set(state, value, initiator) - def set(self, obj, value, initiator): + def remove(self, state, value, initiator, passive=False): + self.set(state, None, initiator) + + def set(self, state, value, initiator): raise NotImplementedError() - def set_committed_value(self, obj, value): - """set an attribute value on the given instance and 'commit' it. - - this indicates that the given value is the "persisted" value, - and history will be logged only if a newly set value is not - equal to this value. - - this is typically used by deferred/lazy attribute loaders - to set object attributes after the initial load. - """ + def get_committed_value(self, state): + """return the unchanged value of this attribute""" - state = obj._state - orig = state.get('original', None) - if orig is not None: - orig.commit_attribute(self, obj, value) - # remove per-instance callable, if any - state.pop(('callable', self), None) - obj.__dict__[self.key] = value - return value + if self.key in state.committed_state: + if state.committed_state[self.key] is NO_VALUE: + return None + else: + return state.committed_state.get(self.key) + else: + return self.get(state) + + def set_committed_value(self, state, value): + """set an attribute value on the given instance and 'commit' it.""" - def set_raw_value(self, obj, value): - obj.__dict__[self.key] = value + state.commit_attr(self, value) return value - def fire_append_event(self, obj, value, initiator): - obj._state['modified'] = True - if self.trackparent and value is not None: - self.sethasparent(value, True) - for ext in self.extensions: - ext.append(obj, value, initiator or self) +class ScalarAttributeImpl(AttributeImpl): + """represents a scalar value-holding InstrumentedAttribute.""" - def fire_remove_event(self, obj, value, initiator): - obj._state['modified'] = True - if self.trackparent and value is not None: - self.sethasparent(value, False) - for ext in self.extensions: - ext.remove(obj, value, initiator or self) + accepts_scalar_loader = True - def fire_replace_event(self, obj, value, previous, initiator): - obj._state['modified'] = True - if self.trackparent: - if value is not None: - self.sethasparent(value, True) - if previous is not None: - self.sethasparent(previous, False) - for ext in self.extensions: - ext.set(obj, value, previous, initiator or self) + def delete(self, state): + if self.key not in state.committed_state: + state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE) - property = property(lambda s: class_mapper(s.class_).get_property(s.key), - doc="the MapperProperty object associated with this attribute") + # TODO: catch key errors, convert to attributeerror? 
+ del state.dict[self.key] + state.modified=True -InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute) + def get_history(self, state, passive=False): + return _create_history(self, state, state.dict.get(self.key, NO_VALUE)) - -class InstrumentedScalarAttribute(InstrumentedAttribute): - """represents a scalar-holding InstrumentedAttribute.""" - - def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): - super(InstrumentedScalarAttribute, self).__init__(class_, manager, key, - callable_, trackparent=trackparent, extension=extension, - compare_function=compare_function, **kwargs) - self.mutable_scalars = mutable_scalars + def set(self, state, value, initiator): + if initiator is self: + return + + if self.key not in state.committed_state: + state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE) + + state.dict[self.key] = value + state.modified=True + def type(self): + self.property.columns[0].type + type = property(type) + +class MutableScalarAttributeImpl(ScalarAttributeImpl): + """represents a scalar value-holding InstrumentedAttribute, which can detect + changes within the value itself. + """ + + def __init__(self, class_, key, callable_, copy_function=None, compare_function=None, **kwargs): + super(ScalarAttributeImpl, self).__init__(class_, key, callable_, compare_function=compare_function, **kwargs) + class_._class_state.has_mutable_scalars = True if copy_function is None: - copy_function = self.__copy + raise exceptions.ArgumentError("MutableScalarAttributeImpl requires a copy function") self.copy = copy_function - def __copy(self, item): - # scalar values are assumed to be immutable unless a copy function - # is passed - return item - - def __delete__(self, obj): - old = self.get(obj) - del obj.__dict__[self.key] - self.fire_remove_event(obj, old, self) - - def check_mutable_modified(self, obj): - if self.mutable_scalars: - h = self.get_history(obj, passive=True) - if h is not None and h.is_modified(): - obj._state['modified'] = True - return True - else: - return False + def get_history(self, state, passive=False): + return _create_history(self, state, state.dict.get(self.key, NO_VALUE)) + + def commit_to_state(self, state, value): + state.committed_state[self.key] = self.copy(value) + + def check_mutable_modified(self, state): + (added, unchanged, deleted) = self.get_history(state, passive=True) + if added or deleted: + state.modified = True + return True else: return False - def set(self, obj, value, initiator): - """Set a value on the given object. + def set(self, state, value, initiator): + if initiator is self: + return + + if self.key not in state.committed_state: + if self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + else: + state.committed_state[self.key] = NO_VALUE + + state.dict[self.key] = value + state.modified=True + + +class ScalarObjectAttributeImpl(ScalarAttributeImpl): + """represents a scalar-holding InstrumentedAttribute, where the target object is also instrumented. + + Adds events to delete/set operations. 
+ """ + + accepts_scalar_loader = False + + def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): + super(ScalarObjectAttributeImpl, self).__init__(class_, key, + callable_, trackparent=trackparent, extension=extension, + compare_function=compare_function, **kwargs) + if compare_function is None: + self.is_equal = identity_equal + + def delete(self, state): + old = self.get(state) + # TODO: catch key errors, convert to attributeerror? + del state.dict[self.key] + self.fire_remove_event(state, old, self) + + def get_history(self, state, passive=False): + if self.key in state.dict: + return _create_history(self, state, state.dict[self.key]) + else: + current = self.get(state, passive=passive) + if current is PASSIVE_NORESULT: + return (None, None, None) + else: + return _create_history(self, state, current) + + def set(self, state, value, initiator): + """Set a value on the given InstanceState. `initiator` is the ``InstrumentedAttribute`` that initiated the ``set()` operation and is used to control the depth of a circular @@ -339,31 +425,57 @@ class InstrumentedScalarAttribute(InstrumentedAttribute): if initiator is self: return - state = obj._state - # if an instance-wide "trigger" was set, call that - if 'trigger' in state: - trig = state['trigger'] - del state['trigger'] - trig() + if value is not None and not hasattr(value, '_state'): + raise TypeError("Can not assign %s instance to %s's %r attribute, " + "a mapped instance was expected." % ( + type(value).__name__, type(state.obj()).__name__, self.key)) - old = self.get(obj) - obj.__dict__[self.key] = value - self.fire_replace_event(obj, value, old, initiator) + # TODO: add options to allow the get() to be passive + old = self.get(state) + state.dict[self.key] = value + self.fire_replace_event(state, value, old, initiator) - type = property(lambda self: self.property.columns[0].type) + def fire_remove_event(self, state, value, initiator): + if self.key not in state.committed_state: + state.committed_state[self.key] = value + state.modified = True - -class InstrumentedCollectionAttribute(InstrumentedAttribute): + if self.trackparent and value is not None: + self.sethasparent(value._state, False) + + instance = state.obj() + for ext in self.extensions: + ext.remove(instance, value, initiator or self) + + def fire_replace_event(self, state, value, previous, initiator): + if self.key not in state.committed_state: + state.committed_state[self.key] = previous + state.modified = True + + if self.trackparent: + if value is not None: + self.sethasparent(value._state, True) + if previous is not value and previous is not None: + self.sethasparent(previous._state, False) + + instance = state.obj() + for ext in self.extensions: + ext.set(instance, value, previous, initiator or self) + +class CollectionAttributeImpl(AttributeImpl): """A collection-holding attribute that instruments changes in membership. + Only handles collections of instrumented objects. + InstrumentedCollectionAttribute holds an arbitrary, user-specified container object (defaulting to a list) and brokers access to the CollectionAdapter, a "view" onto that object that presents consistent bag semantics to the orm layer independent of the user data implementation. 
""" - - def __init__(self, class_, manager, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): - super(InstrumentedCollectionAttribute, self).__init__(class_, manager, + accepts_scalar_loader = False + + def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): + super(CollectionAttributeImpl, self).__init__(class_, key, callable_, trackparent=trackparent, extension=extension, compare_function=compare_function, **kwargs) @@ -375,59 +487,90 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): typecallable = list self.collection_factory = \ collections._prepare_instrumentation(typecallable) + # may be removed in 0.5: self.collection_interface = \ util.duck_type_collection(self.collection_factory()) def __copy(self, item): return [y for y in list(collections.collection_adapter(item))] - def __set__(self, obj, value): - """Replace the current collection with a new one.""" + def get_history(self, state, passive=False): + current = self.get(state, passive=passive) + if current is PASSIVE_NORESULT: + return (None, None, None) + else: + return _create_history(self, state, current) - setting_type = util.duck_type_collection(value) + def fire_append_event(self, state, value, initiator): + if self.key not in state.committed_state and self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) - if value is None or setting_type != self.collection_interface: - raise exceptions.ArgumentError( - "Incompatible collection type on assignment: %s is not %s-like" % - (type(value).__name__, self.collection_interface.__name__)) + state.modified = True - if hasattr(value, '_sa_adapter'): - self.set(obj, list(getattr(value, '_sa_adapter')), None) - elif setting_type == dict: - self.set(obj, value.values(), None) - else: - self.set(obj, value, None) + if self.trackparent and value is not None: + self.sethasparent(value._state, True) + instance = state.obj() + for ext in self.extensions: + ext.append(instance, value, initiator or self) + + def fire_pre_remove_event(self, state, initiator): + if self.key not in state.committed_state and self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + + def fire_remove_event(self, state, value, initiator): + if self.key not in state.committed_state and self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) - def __delete__(self, obj): - if self.key not in obj.__dict__: + state.modified = True + + if self.trackparent and value is not None: + self.sethasparent(value._state, False) + + instance = state.obj() + for ext in self.extensions: + ext.remove(instance, value, initiator or self) + + def delete(self, state): + if self.key not in state.dict: return - obj._state['modified'] = True + state.modified = True - collection = self._get_collection(obj) + collection = self.get_collection(state) collection.clear_with_event() - del obj.__dict__[self.key] + # TODO: catch key errors, convert to attributeerror? 
+ del state.dict[self.key] - def initialize(self, obj): + def initialize(self, state): """Initialize this attribute on the given object instance with an empty collection.""" - _, user_data = self._build_collection(obj) - obj.__dict__[self.key] = user_data + _, user_data = self._build_collection(state) + state.dict[self.key] = user_data return user_data - def append(self, obj, value, initiator): + def append(self, state, value, initiator, passive=False): if initiator is self: return - collection = self._get_collection(obj) - collection.append_with_event(value, initiator) - def remove(self, obj, value, initiator): + collection = self.get_collection(state, passive=passive) + if collection is PASSIVE_NORESULT: + state.get_pending(self.key).append(value) + self.fire_append_event(state, value, initiator) + else: + collection.append_with_event(value, initiator) + + def remove(self, state, value, initiator, passive=False): if initiator is self: return - collection = self._get_collection(obj) - collection.remove_with_event(value, initiator) - def set(self, obj, value, initiator): + collection = self.get_collection(state, passive=passive) + if collection is PASSIVE_NORESULT: + state.get_pending(self.key).remove(value) + self.fire_remove_event(state, value, initiator) + else: + collection.remove_with_event(value, initiator) + + def set(self, state, value, initiator): """Set a value on the given object. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -438,72 +581,105 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): if initiator is self: return - state = obj._state - # if an instance-wide "trigger" was set, call that - if 'trigger' in state: - trig = state['trigger'] - del state['trigger'] - trig() - - old = self.get(obj) - old_collection = self._get_collection(obj, old) - - new_collection, user_data = self._build_collection(obj) - self._load_collection(obj, value or [], emit_events=True, - collection=new_collection) - - obj.__dict__[self.key] = user_data - state['modified'] = True - - # mark all the old elements as detached from the parent - if old_collection: - old_collection.clear_with_event() - old_collection.unlink(old) - - def set_committed_value(self, obj, value): - """Set an attribute value on the given instance and 'commit' it.""" - - state = obj._state - orig = state.get('original', None) - - collection, user_data = self._build_collection(obj) - self._load_collection(obj, value or [], emit_events=False, - collection=collection) - value = user_data - - if orig is not None: - orig.commit_attribute(self, obj, value) - # remove per-instance callable, if any - state.pop(('callable', self), None) - obj.__dict__[self.key] = value - return value + self._set_iterable( + state, value, + lambda adapter, i: adapter.adapt_like_to_iterable(i)) - def _build_collection(self, obj): - user_data = self.collection_factory() - collection = collections.CollectionAdapter(self, obj, user_data) - return collection, user_data + def _set_iterable(self, state, iterable, adapter=None): + """Set a collection value from an iterable of state-bearers. - def _load_collection(self, obj, values, emit_events=True, collection=None): - collection = collection or self._get_collection(obj) - if values is None: - return - elif emit_events: - for item in values: - collection.append_with_event(item) + ``adapter`` is an optional callable invoked with a CollectionAdapter + and the iterable. Should return an iterable of state-bearing + instances suitable for appending via a CollectionAdapter. 
Can be used + for, e.g., adapting an incoming dictionary into an iterator of values + rather than keys. + + """ + # pulling a new collection first so that an adaptation exception does + # not trigger a lazy load of the old collection. + new_collection, user_data = self._build_collection(state) + if adapter: + new_values = list(adapter(new_collection, iterable)) else: - for item in values: + new_values = list(iterable) + + old = self.get(state) + + # ignore re-assignment of the current collection, as happens + # implicitly with in-place operators (foo.collection |= other) + if old is iterable: + return + + if self.key not in state.committed_state: + state.committed_state[self.key] = self.copy(old) + + old_collection = self.get_collection(state, old) + + state.dict[self.key] = user_data + state.modified = True + + collections.bulk_replace(new_values, old_collection, new_collection) + old_collection.unlink(old) + + + def set_committed_value(self, state, value): + """Set an attribute value on the given instance and 'commit' it. + + Loads the existing collection from lazy callables in all cases. + """ + + collection, user_data = self._build_collection(state) + + if value: + for item in value: collection.append_without_event(item) - - def _get_collection(self, obj, user_data=None): + + state.callables.pop(self.key, None) + state.dict[self.key] = user_data + + if self.key in state.pending: + # pending items. commit loaded data, add/remove new data + state.committed_state[self.key] = list(value or []) + added = state.pending[self.key].added_items + removed = state.pending[self.key].deleted_items + for item in added: + collection.append_without_event(item) + for item in removed: + collection.remove_without_event(item) + del state.pending[self.key] + elif self.key in state.committed_state: + # no pending items. remove committed state if any. + # (this can occur with an expired attribute) + del state.committed_state[self.key] + + return user_data + + def _build_collection(self, state): + """build a new, blank collection and return it wrapped in a CollectionAdapter.""" + + user_data = self.collection_factory() + collection = collections.CollectionAdapter(self, state, user_data) + return collection, user_data + + def get_collection(self, state, user_data=None, passive=False): + """retrieve the CollectionAdapter associated with the given state. + + Creates a new CollectionAdapter if one does not exist. + + """ + if user_data is None: - user_data = self.get(obj) + user_data = self.get(state, passive=passive) + if user_data is PASSIVE_NORESULT: + return user_data try: return getattr(user_data, '_sa_adapter') except AttributeError: - collections.CollectionAdapter(self, obj, user_data) + # TODO: this codepath never occurs, and this + # except/initialize should be removed + collections.CollectionAdapter(self, state, user_data) return getattr(user_data, '_sa_adapter') - class GenericBackrefExtension(interfaces.AttributeExtension): """An extension which synchronizes a two-way relationship. @@ -516,365 +692,604 @@ class GenericBackrefExtension(interfaces.AttributeExtension): def __init__(self, key): self.key = key - def set(self, obj, child, oldchild, initiator): + def set(self, instance, child, oldchild, initiator): if oldchild is child: return if oldchild is not None: - getattr(oldchild.__class__, self.key).remove(oldchild, obj, initiator) + # With lazy=None, there's no guarantee that the full collection is + # present when updating via a backref. 
+ impl = getattr(oldchild.__class__, self.key).impl + try: + impl.remove(oldchild._state, instance, initiator, passive=True) + except (ValueError, KeyError, IndexError): + pass if child is not None: - getattr(child.__class__, self.key).append(child, obj, initiator) + getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True) - def append(self, obj, child, initiator): - getattr(child.__class__, self.key).append(child, obj, initiator) + def append(self, instance, child, initiator): + getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True) - def remove(self, obj, child, initiator): - getattr(child.__class__, self.key).remove(child, obj, initiator) + def remove(self, instance, child, initiator): + if child is not None: + getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator, passive=True) -class CommittedState(object): - """Store the original state of an object when the ``commit()` - method on the attribute manager is called. - """ +class ClassState(object): + """tracks state information at the class level.""" + def __init__(self): + self.mappers = {} + self.attrs = {} + self.has_mutable_scalars = False + +import sets +_empty_set = sets.ImmutableSet() + +class InstanceState(object): + """tracks state information at the instance level.""" + + def __init__(self, obj): + self.class_ = obj.__class__ + self.obj = weakref.ref(obj, self.__cleanup) + self.dict = obj.__dict__ + self.committed_state = {} + self.modified = False + self.callables = {} + self.parents = {} + self.pending = {} + self.appenders = {} + self.instance_dict = None + self.runid = None + self.expired_attributes = _empty_set + + def __cleanup(self, ref): + # tiptoe around Python GC unpredictableness + instance_dict = self.instance_dict + if instance_dict is None: + return - NO_VALUE = object() + instance_dict = instance_dict() + if instance_dict is None or instance_dict._mutex is None: + return - def __init__(self, manager, obj): - self.data = {} - for attr in manager.managed_attributes(obj.__class__): - self.commit_attribute(attr, obj) + # the mutexing here is based on the assumption that gc.collect() + # may be firing off cleanup handlers in a different thread than that + # which is normally operating upon the instance dict. + instance_dict._mutex.acquire() + try: + try: + self.__resurrect(instance_dict) + except: + # catch app cleanup exceptions. no other way around this + # without warnings being produced + pass + finally: + instance_dict._mutex.release() - def commit_attribute(self, attr, obj, value=NO_VALUE): - """Establish the value of attribute `attr` on instance `obj` - as *committed*. + def _check_resurrect(self, instance_dict): + instance_dict._mutex.acquire() + try: + return self.obj() or self.__resurrect(instance_dict) + finally: + instance_dict._mutex.release() - This corresponds to a previously saved state being restored. - """ + def get_pending(self, key): + if key not in self.pending: + self.pending[key] = PendingCollection() + return self.pending[key] - if value is CommittedState.NO_VALUE: - if attr.key in obj.__dict__: - value = obj.__dict__[attr.key] - if value is not CommittedState.NO_VALUE: - self.data[attr.key] = attr.copy(value) - - # not tracking parent on lazy-loaded instances at the moment. 
- # its not needed since they will be "optimistically" tested - #if attr.uselist: - #if attr.trackparent: - # [attr.sethasparent(x, True) for x in self.data[attr.key] if x is not None] - #else: - #if attr.trackparent and value is not None: - # attr.sethasparent(value, True) - - def rollback(self, manager, obj): - for attr in manager.managed_attributes(obj.__class__): - if self.data.has_key(attr.key): - if not isinstance(attr, InstrumentedCollectionAttribute): - obj.__dict__[attr.key] = self.data[attr.key] - else: - collection = attr._get_collection(obj) - collection.clear_without_event() - for item in self.data[attr.key]: - collection.append_without_event(item) + def is_modified(self): + if self.modified: + return True + elif self.class_._class_state.has_mutable_scalars: + for attr in _managed_attributes(self.class_): + if hasattr(attr.impl, 'check_mutable_modified') and attr.impl.check_mutable_modified(self): + return True else: - del obj.__dict__[attr.key] - - def __repr__(self): - return "CommittedState: %s" % repr(self.data) - -class AttributeHistory(object): - """Calculate the *history* of a particular attribute on a - particular instance, based on the ``CommittedState`` associated - with the instance, if any. - """ - - def __init__(self, attr, obj, current, passive=False): - self.attr = attr - - # get the "original" value. if a lazy load was fired when we got - # the 'current' value, this "original" was also populated just - # now as well (therefore we have to get it second) - orig = obj._state.get('original', None) - if orig is not None: - original = orig.data.get(attr.key) - else: - original = None - - if isinstance(attr, InstrumentedCollectionAttribute): - self._current = current - s = util.Set(original or []) - self._added_items = [] - self._unchanged_items = [] - self._deleted_items = [] - if current: - collection = attr._get_collection(obj, current) - for a in collection: - if a in s: - self._unchanged_items.append(a) - else: - self._added_items.append(a) - for a in s: - if a not in self._unchanged_items: - self._deleted_items.append(a) + return False else: - self._current = [current] - if attr.is_equal(current, original): - self._unchanged_items = [current] - self._added_items = [] - self._deleted_items = [] - else: - self._added_items = [current] - if original is not None: - self._deleted_items = [original] - else: - self._deleted_items = [] - self._unchanged_items = [] - - def __iter__(self): - return iter(self._current) - - def is_modified(self): - return len(self._deleted_items) > 0 or len(self._added_items) > 0 - - def added_items(self): - return self._added_items - - def unchanged_items(self): - return self._unchanged_items + return False - def deleted_items(self): - return self._deleted_items + def __resurrect(self, instance_dict): + if self.is_modified(): + # store strong ref'ed version of the object; will revert + # to weakref when changes are persisted + obj = new_instance(self.class_, state=self) + self.obj = weakref.ref(obj, self.__cleanup) + self._strong_obj = obj + obj.__dict__.update(self.dict) + self.dict = obj.__dict__ + return obj + else: + del instance_dict[self.dict['_instance_key']] + return None - def hasparent(self, obj): - """Deprecated. This should be called directly from the appropriate ``InstrumentedAttribute`` object. 
+ def __getstate__(self): + return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':self.expired_attributes, 'callables':self.callables} + + def __setstate__(self, state): + self.committed_state = state['committed_state'] + self.parents = state['parents'] + self.pending = state['pending'] + self.modified = state['modified'] + self.obj = weakref.ref(state['instance']) + self.class_ = self.obj().__class__ + self.dict = self.obj().__dict__ + self.callables = state['callables'] + self.runid = None + self.appenders = {} + self.expired_attributes = state['expired_attributes'] + + def initialize(self, key): + getattr(self.class_, key).impl.initialize(self) + + def set_callable(self, key, callable_): + self.dict.pop(key, None) + self.callables[key] = callable_ + + def __call__(self): + """__call__ allows the InstanceState to act as a deferred + callable for loading expired attributes, which is also + serializable. """ + instance = self.obj() + unmodified = self.unmodified + self.class_._class_state.deferred_scalar_loader(instance, [ + attr.impl.key for attr in _managed_attributes(self.class_) if + attr.impl.accepts_scalar_loader and + attr.impl.key in self.expired_attributes and + attr.impl.key in unmodified + ]) + for k in self.expired_attributes: + self.callables.pop(k, None) + self.expired_attributes.clear() + return ATTR_WAS_SET + + def unmodified(self): + """a set of keys which have no uncommitted changes""" + + return util.Set([ + attr.impl.key for attr in _managed_attributes(self.class_) if + attr.impl.key not in self.committed_state + and (not hasattr(attr.impl, 'commit_to_state') or not attr.impl.check_mutable_modified(self)) + ]) + unmodified = property(unmodified) + + def expire_attributes(self, attribute_names): + self.expired_attributes = util.Set(self.expired_attributes) + + if attribute_names is None: + for attr in _managed_attributes(self.class_): + self.dict.pop(attr.impl.key, None) + self.expired_attributes.add(attr.impl.key) + if attr.impl.accepts_scalar_loader: + self.callables[attr.impl.key] = self + + self.committed_state = {} + else: + for key in attribute_names: + self.dict.pop(key, None) + self.committed_state.pop(key, None) + self.expired_attributes.add(key) + if getattr(self.class_, key).impl.accepts_scalar_loader: + self.callables[key] = self + + def reset(self, key): + """remove the given attribute and any callables associated with it.""" + self.dict.pop(key, None) + self.callables.pop(key, None) + + def commit_attr(self, attr, value): + """set the value of an attribute and mark it 'committed'.""" + + if hasattr(attr, 'commit_to_state'): + attr.commit_to_state(self, value) + else: + self.committed_state.pop(attr.key, None) + self.dict[attr.key] = value + self.pending.pop(attr.key, None) + self.appenders.pop(attr.key, None) - return self.attr.hasparent(obj) + # we have a value so we can also unexpire it + self.callables.pop(attr.key, None) + if attr.key in self.expired_attributes: + self.expired_attributes.remove(attr.key) -class AttributeManager(object): - """Allow the instrumentation of object attributes.""" + def commit(self, keys): + """commit all attributes named in the given list of key names. 
- def __init__(self): - # will cache attributes, indexed by class objects - self._inherited_attribute_cache = weakref.WeakKeyDictionary() - self._noninherited_attribute_cache = weakref.WeakKeyDictionary() + This is used by a partial-attribute load operation to mark committed those attributes + which were refreshed from the database. - def clear_attribute_cache(self): - self._attribute_cache.clear() - - def rollback(self, *obj): - """Retrieve the committed history for each object in the given - list, and rolls back the attributes each instance to their - original value. + Attributes marked as "expired" can potentially remain "expired" after this step + if a value was not populated in state.dict. """ - for o in obj: - orig = o._state.get('original') - if orig is not None: - orig.rollback(self, o) - else: - self._clear(o) + if self.class_._class_state.has_mutable_scalars: + for key in keys: + attr = getattr(self.class_, key).impl + if hasattr(attr, 'commit_to_state') and attr.key in self.dict: + attr.commit_to_state(self, self.dict[attr.key]) + else: + self.committed_state.pop(attr.key, None) + self.pending.pop(key, None) + self.appenders.pop(key, None) + else: + for key in keys: + self.committed_state.pop(key, None) + self.pending.pop(key, None) + self.appenders.pop(key, None) - def _clear(self, obj): - for attr in self.managed_attributes(obj.__class__): - try: - del obj.__dict__[attr.key] - except KeyError: - pass + # unexpire attributes which have loaded + for key in self.expired_attributes.intersection(keys): + if key in self.dict: + self.expired_attributes.remove(key) + self.callables.pop(key, None) - def commit(self, *obj): - """Create a ``CommittedState`` instance for each object in the given list, representing - its *unchanged* state, and associates it with the instance. - ``AttributeHistory`` objects will indicate the modified state of - instance attributes as compared to its value in this - ``CommittedState`` object. - """ + def commit_all(self): + """commit all attributes unconditionally. - for o in obj: - o._state['original'] = CommittedState(self, o) - o._state['modified'] = False + This is used after a flush() or a regular instance load or refresh operation + to mark committed all populated attributes. - def managed_attributes(self, class_): - """Return a list of all ``InstrumentedAttribute`` objects - associated with the given class. + Attributes marked as "expired" can potentially remain "expired" after this step + if a value was not populated in state.dict. 
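The expire/commit machinery above reduces to a small pattern: expired keys are popped from the instance __dict__, a per-key callable reloads them on the next access, and commit() un-expires any key that received a value. A minimal standalone sketch of that pattern, with made-up names (MiniState, reload_from_db) and none of SQLAlchemy's real API:

    class MiniState(object):
        def __init__(self, obj, loader):
            self.obj = obj
            self.loader = loader          # callable(obj, keys), standing in for a DB refresh
            self.expired = set()

        def expire(self, keys):
            for key in keys:
                self.obj.__dict__.pop(key, None)
                self.expired.add(key)

        def __call__(self):
            # the deferred loader installed for every expired attribute
            self.loader(self.obj, sorted(self.expired))
            self.expired.clear()

    class Thing(object):
        def __getattr__(self, key):       # only invoked when key is missing from __dict__
            state = self.__dict__.get('_mini_state')
            if state is not None and key in state.expired:
                state()                   # bulk reload, as InstanceState.__call__ does above
                return self.__dict__[key]
            raise AttributeError(key)

    def reload_from_db(obj, keys):
        for key in keys:
            obj.__dict__[key] = 'fresh-%s' % key

    t = Thing()
    t._mini_state = MiniState(t, reload_from_db)
    t.name = 'old'
    t._mini_state.expire(['name'])
    assert t.name == 'fresh-name'         # the access triggered the reload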
""" + self.committed_state = {} + self.modified = False + self.pending = {} + self.appenders = {} + + # unexpire attributes which have loaded + for key in list(self.expired_attributes): + if key in self.dict: + self.expired_attributes.remove(key) + self.callables.pop(key, None) + + if self.class_._class_state.has_mutable_scalars: + for attr in _managed_attributes(self.class_): + if hasattr(attr.impl, 'commit_to_state') and attr.impl.key in self.dict: + attr.impl.commit_to_state(self, self.dict[attr.impl.key]) + + # remove strong ref + self._strong_obj = None + + +class WeakInstanceDict(UserDict.UserDict): + """similar to WeakValueDictionary, but wired towards 'state' objects.""" + + def __init__(self, *args, **kw): + self._wr = weakref.ref(self) + # RLock because the mutex is used by a cleanup handler, which can be + # called at any time (including within an already mutexed block) + self._mutex = util.threading.RLock() + UserDict.UserDict.__init__(self, *args, **kw) + + def __getitem__(self, key): + state = self.data[key] + o = state.obj() + if o is None: + o = state._check_resurrect(self) + if o is None: + raise KeyError, key + return o + + def __contains__(self, key): try: - return self._inherited_attribute_cache[class_] + state = self.data[key] + o = state.obj() + if o is None: + o = state._check_resurrect(self) except KeyError: - if not isinstance(class_, type): - raise TypeError(repr(class_) + " is not a type") - inherited = [v for v in [getattr(class_, key, None) for key in dir(class_)] if isinstance(v, InstrumentedAttribute)] - self._inherited_attribute_cache[class_] = inherited - return inherited + return False + return o is not None + + def has_key(self, key): + return key in self - def noninherited_managed_attributes(self, class_): + def __repr__(self): + return "" % id(self) + + def __setitem__(self, key, value): + if key in self.data: + self._mutex.acquire() + try: + if key in self.data: + self.data[key].instance_dict = None + finally: + self._mutex.release() + self.data[key] = value._state + value._state.instance_dict = self._wr + + def __delitem__(self, key): + state = self.data[key] + state.instance_dict = None + del self.data[key] + + def get(self, key, default=None): try: - return self._noninherited_attribute_cache[class_] + state = self.data[key] except KeyError: - if not isinstance(class_, type): - raise TypeError(repr(class_) + " is not a type") - noninherited = [v for v in [getattr(class_, key, None) for key in list(class_.__dict__)] if isinstance(v, InstrumentedAttribute)] - self._noninherited_attribute_cache[class_] = noninherited - return noninherited - - def is_modified(self, object): - for attr in self.managed_attributes(object.__class__): - if attr.check_mutable_modified(object): - return True - return object._state.get('modified', False) - - def init_attr(self, obj): - """Sets up the __sa_attr_state dictionary on the given instance. - - This dictionary is automatically created when the `_state` - attribute of the class is first accessed, but calling it here - will save a single throw of an ``AttributeError`` that occurs - in that creation step. 
- """ + return default + else: + o = state.obj() + if o is None: + # This should only happen + return default + else: + return o + + def items(self): + L = [] + for key, state in self.data.items(): + o = state.obj() + if o is not None: + L.append((key, o)) + return L + + def iteritems(self): + for state in self.data.itervalues(): + value = state.obj() + if value is not None: + yield value._instance_key, value - setattr(obj, '_%s__sa_attr_state' % obj.__class__.__name__, {}) + def iterkeys(self): + return self.data.iterkeys() - def get_history(self, obj, key, **kwargs): - """Return a new ``AttributeHistory`` object for the given - attribute on the given object. - """ + def __iter__(self): + return self.data.iterkeys() + + def __len__(self): + return len(self.values()) + + def itervalues(self): + for state in self.data.itervalues(): + instance = state.obj() + if instance is not None: + yield instance + + def values(self): + L = [] + for state in self.data.values(): + o = state.obj() + if o is not None: + L.append(o) + return L + + def popitem(self): + raise NotImplementedError() - return getattr(obj.__class__, key).get_history(obj, **kwargs) + def pop(self, key, *args): + raise NotImplementedError() - def get_as_list(self, obj, key, passive=False): - """Return an attribute of the given name from the given object. + def setdefault(self, key, default=None): + raise NotImplementedError() - If the attribute is a scalar, return it as a single-item list, - otherwise return a collection based attribute. + def update(self, dict=None, **kwargs): + raise NotImplementedError() - If the attribute's value is to be produced by an unexecuted - callable, the callable will only be executed if the given - `passive` flag is False. - """ + def copy(self): + raise NotImplementedError() - attr = getattr(obj.__class__, key) - x = attr.get(obj, passive=passive) - if x is PASSIVE_NORESULT: - return [] - elif isinstance(attr, InstrumentedCollectionAttribute): - return list(attr._get_collection(obj, x)) - else: - return [x] + def all_states(self): + return self.data.values() - def trigger_history(self, obj, callable): - """Clear all managed object attributes and places the given - `callable` as an attribute-wide *trigger*, which will execute - upon the next attribute access, after which the trigger is - removed. - """ +class StrongInstanceDict(dict): + def all_states(self): + return [o._state for o in self.values()] - self._clear(obj) - try: - del obj._state['original'] - except KeyError: - pass - obj._state['trigger'] = callable +def _create_history(attr, state, current): + original = state.committed_state.get(attr.key, NEVER_SET) - def untrigger_history(self, obj): - """Remove a trigger function set by trigger_history. 
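Stripped of resurrection and locking, WeakInstanceDict above is a weak-value identity map: it keeps only a weak reference to each instance and treats a dead reference as a missing key. A rough standalone sketch of that core behavior (this is illustrative, not the class defined above):

    import weakref

    class MiniIdentityMap(object):
        def __init__(self):
            self._refs = {}

        def add(self, key, obj):
            self._refs[key] = weakref.ref(obj)

        def __getitem__(self, key):
            obj = self._refs[key]()
            if obj is None:
                raise KeyError(key)       # referent was garbage collected
            return obj

        def __contains__(self, key):
            try:
                self[key]
                return True
            except KeyError:
                return False

    class User(object):
        pass

    imap = MiniIdentityMap()
    someuser = User()
    imap.add(('User', 1), someuser)
    assert ('User', 1) in imap
    # once the last strong reference to someuser goes away, the entry behaves
    # as if it were absent (timing depends on the interpreter's GC)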
+ if hasattr(attr, 'get_collection'): + current = attr.get_collection(state, current) + if original is NO_VALUE: + return (list(current), [], []) + elif original is NEVER_SET: + return ([], list(current), []) + else: + collection = util.OrderedIdentitySet(current) + s = util.OrderedIdentitySet(original) + return (list(collection.difference(s)), list(collection.intersection(s)), list(s.difference(collection))) + else: + if current is NO_VALUE: + if original not in [None, NEVER_SET, NO_VALUE]: + deleted = [original] + else: + deleted = [] + return ([], [], deleted) + elif original is NO_VALUE: + return ([current], [], []) + elif original is NEVER_SET or attr.is_equal(current, original) is True: # dont let ClauseElement expressions here trip things up + return ([], [current], []) + else: + if original is not None: + deleted = [original] + else: + deleted = [] + return ([current], [], deleted) - Does not restore the previous state of the object. - """ +class PendingCollection(object): + """stores items appended and removed from a collection that has not been loaded yet. - del obj._state['trigger'] + When the collection is loaded, the changes present in PendingCollection are applied + to produce the final result. + """ - def has_trigger(self, obj): - """Return True if the given object has a trigger function set - by ``trigger_history()``. - """ + def __init__(self): + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() - return 'trigger' in obj._state + def append(self, value): + if value in self.deleted_items: + self.deleted_items.remove(value) + self.added_items.add(value) - def reset_instance_attribute(self, obj, key): - """Remove any per-instance callable functions corresponding to - given attribute `key` from the given object, and remove this - attribute from the given object's dictionary. - """ + def remove(self, value): + if value in self.added_items: + self.added_items.remove(value) + self.deleted_items.add(value) - attr = getattr(obj.__class__, key) - attr.reset(obj) +def _managed_attributes(class_): + """return all InstrumentedAttributes associated with the given class_ and its superclasses.""" - def reset_class_managed(self, class_): - """Remove all ``InstrumentedAttribute`` property objects from - the given class. - """ + return chain(*[cl._class_state.attrs.values() for cl in class_.__mro__[:-1] if hasattr(cl, '_class_state')]) - for attr in self.noninherited_managed_attributes(class_): - delattr(class_, attr.key) - self._inherited_attribute_cache.pop(class_,None) - self._noninherited_attribute_cache.pop(class_,None) +def get_history(state, key, **kwargs): + return getattr(state.class_, key).impl.get_history(state, **kwargs) - def is_class_managed(self, class_, key): - """Return True if the given `key` correponds to an - instrumented property on the given class. - """ - return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute) +def get_as_list(state, key, passive=False): + """return an InstanceState attribute as a list, + regardless of it being a scalar or collection-based + attribute. - def init_instance_attribute(self, obj, key, callable_=None): - """Initialize an attribute on an instance to either a blank - value, cancelling out any class- or instance-level callables - that were present, or if a `callable` is supplied set the - callable to be invoked when the attribute is next accessed. - """ + returns None if passive=True and the getter returns + PASSIVE_NORESULT. 
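The (added, unchanged, deleted) triple produced by _create_history() above can be shown with plain values. This is a simplified standalone sketch covering only the common scalar and list cases, without the NEVER_SET/NO_VALUE sentinels:

    def simple_history(current, original):
        if isinstance(current, (list, set)):
            current_set = set(current)
            original_set = set(original or [])
            added = [x for x in current if x not in original_set]
            unchanged = [x for x in current if x in original_set]
            deleted = [x for x in (original or []) if x not in current_set]
            return added, unchanged, deleted
        if current == original:
            return [], [current], []
        deleted = []
        if original is not None:
            deleted = [original]
        return [current], [], deleted

    assert simple_history('ed', 'ed') == ([], ['ed'], [])
    assert simple_history('mary', 'ed') == (['mary'], [], ['ed'])
    assert simple_history(['a', 'c'], ['a', 'b']) == (['c'], ['a'], ['b'])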
+ """ - getattr(obj.__class__, key).set_callable(obj, callable_) + attr = getattr(state.class_, key).impl + x = attr.get(state, passive=passive) + if x is PASSIVE_NORESULT: + return None + elif hasattr(attr, 'get_collection'): + return attr.get_collection(state, x, passive=passive) + elif isinstance(x, list): + return x + else: + return [x] + +def has_parent(class_, instance, key, optimistic=False): + return getattr(class_, key).impl.hasparent(instance._state, optimistic=optimistic) + +def _create_prop(class_, key, uselist, callable_, typecallable, useobject, mutable_scalars, impl_class, **kwargs): + if impl_class: + return impl_class(class_, key, typecallable, **kwargs) + elif uselist: + return CollectionAttributeImpl(class_, key, callable_, typecallable, **kwargs) + elif useobject: + return ScalarObjectAttributeImpl(class_, key, callable_,**kwargs) + elif mutable_scalars: + return MutableScalarAttributeImpl(class_, key, callable_, **kwargs) + else: + return ScalarAttributeImpl(class_, key, callable_, **kwargs) + +def manage(instance): + """initialize an InstanceState on the given instance.""" + + if not hasattr(instance, '_state'): + instance._state = InstanceState(instance) + +def new_instance(class_, state=None): + """create a new instance of class_ without its __init__() method being called. + + Also initializes an InstanceState on the new instance. + """ - def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs): - """Create a scalar property object, defaulting to - ``InstrumentedAttribute``, which will communicate change - events back to this ``AttributeManager``. - """ + s = class_.__new__(class_) + if state: + s._state = state + else: + s._state = InstanceState(s) + return s - if uselist: - return InstrumentedCollectionAttribute(class_, self, key, - callable_, - typecallable, - **kwargs) - else: - return InstrumentedScalarAttribute(class_, self, key, callable_, - **kwargs) +def _init_class_state(class_): + if not '_class_state' in class_.__dict__: + class_._class_state = ClassState() - def get_attribute(self, obj_or_cls, key): - """Register an attribute at the class level to be instrumented - for all instances of the class. - """ +def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_loader=None): + _init_class_state(class_) + class_._class_state.deferred_scalar_loader=deferred_scalar_loader - if isinstance(obj_or_cls, type): - return getattr(obj_or_cls, key) - else: - return getattr(obj_or_cls.__class__, key) + oldinit = None + doinit = False - def register_attribute(self, class_, key, uselist, callable_=None, **kwargs): - """Register an attribute at the class level to be instrumented - for all instances of the class. 
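new_instance() above relies on a plain Python fact: cls.__new__(cls) allocates an instance without running __init__, after which attributes can be written straight into __dict__. A small illustration with an invented class:

    class Account(object):
        def __init__(self, owner):
            self.owner = owner
            self.audit_log = ['created']      # constructor side effect we want to skip on load

    loaded = Account.__new__(Account)         # no __init__ call, no side effects
    loaded.__dict__.update({'owner': 'ed', 'balance': 100})

    assert loaded.owner == 'ed'
    assert not hasattr(loaded, 'audit_log')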
- """ + def init(instance, *args, **kwargs): + if not hasattr(instance, '_state'): + instance._state = InstanceState(instance) + + if extra_init: + extra_init(class_, oldinit, instance, args, kwargs) + + try: + if doinit: + oldinit(instance, *args, **kwargs) + elif args or kwargs: + # simulate error message raised by object(), but don't copy + # the text verbatim + raise TypeError("default constructor for object() takes no parameters") + except: + if on_exception: + on_exception(class_, oldinit, instance, args, kwargs) + raise + + + # override oldinit + oldinit = class_.__init__ + if oldinit is None or not hasattr(oldinit, '_oldinit'): + init._oldinit = oldinit + class_.__init__ = init + # if oldinit is already one of our 'init' methods, replace it + elif hasattr(oldinit, '_oldinit'): + init._oldinit = oldinit._oldinit + class_.__init = init + oldinit = oldinit._oldinit + + if oldinit is not None: + doinit = oldinit is not object.__init__ + try: + init.__name__ = oldinit.__name__ + init.__doc__ = oldinit.__doc__ + except: + # cant set __name__ in py 2.3 ! + pass - # firt invalidate the cache for the given class - # (will be reconstituted as needed, while getting managed attributes) - self._inherited_attribute_cache.pop(class_, None) - self._noninherited_attribute_cache.pop(class_, None) - - if not hasattr(class_, '_state'): - def _get_state(self): - if not hasattr(self, '_sa_attr_state'): - self._sa_attr_state = {} - return self._sa_attr_state - class_._state = property(_get_state) - - typecallable = kwargs.pop('typecallable', None) - if isinstance(typecallable, InstrumentedAttribute): - typecallable = None - setattr(class_, key, self.create_prop(class_, key, uselist, callable_, - typecallable=typecallable, **kwargs)) - - def init_collection(self, instance, key): - """Initialize a collection attribute and return the collection adapter.""" - - attr = self.get_attribute(instance, key) - user_data = attr.initialize(instance) - return attr._get_collection(instance, user_data) +def unregister_class(class_): + if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): + if class_.__init__._oldinit is not None: + class_.__init__ = class_.__init__._oldinit + else: + delattr(class_, '__init__') + + if '_class_state' in class_.__dict__: + _class_state = class_.__dict__['_class_state'] + for key, attr in _class_state.attrs.iteritems(): + if key in class_.__dict__: + delattr(class_, attr.impl.key) + delattr(class_, '_class_state') + +def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, mutable_scalars=False, impl_class=None, **kwargs): + _init_class_state(class_) + + typecallable = kwargs.pop('typecallable', None) + if isinstance(typecallable, InstrumentedAttribute): + typecallable = None + comparator = kwargs.pop('comparator', None) + + if key in class_.__dict__ and isinstance(class_.__dict__[key], InstrumentedAttribute): + # this currently only occurs if two primary mappers are made for the same class. + # TODO: possibly have InstrumentedAttribute check "entity_name" when searching for impl. 
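register_class() above swaps in a wrapper for the mapped class's __init__ and remembers the original so the wrapping can be undone. A stripped-down sketch of that wrapping idea, using invented names and none of the extra_init/on_exception hooks:

    def instrument_init(cls, state_factory):
        original = cls.__init__

        def init(self, *args, **kwargs):
            self._mini_state = state_factory(self)    # set up per-instance state first
            if original is not object.__init__:
                original(self, *args, **kwargs)

        init._original = original                     # kept so the wrapping can be removed later
        cls.__init__ = init

    class Widget(object):
        def __init__(self, name):
            self.name = name

    instrument_init(Widget, lambda obj: {'committed': {}})
    w = Widget('lever')
    assert w.name == 'lever'
    assert w._mini_state == {'committed': {}}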
+ # raise an error if two attrs attached simultaneously otherwise + return + + if proxy_property: + proxy_type = proxied_attribute_factory(proxy_property) + inst = proxy_type(key, proxy_property, comparator) + else: + inst = InstrumentedAttribute(_create_prop(class_, key, uselist, callable_, useobject=useobject, + typecallable=typecallable, mutable_scalars=mutable_scalars, impl_class=impl_class, **kwargs), comparator=comparator) + + setattr(class_, key, inst) + class_._class_state.attrs[key] = inst + +def unregister_attribute(class_, key): + class_state = class_._class_state + if key in class_state.attrs: + del class_._class_state.attrs[key] + delattr(class_, key) + +def init_collection(instance, key): + """Initialize a collection attribute and return the collection adapter.""" + attr = getattr(instance.__class__, key).impl + state = instance._state + user_data = attr.initialize(state) + return attr.get_collection(state, user_data) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 7ade882f5c..c8fc2f189a 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -15,11 +15,11 @@ and return values to events:: from sqlalchemy.orm.collections import collection class MyClass(object): # ... - + @collection.adds(1) def store(self, item): self.data.append(item) - + @collection.removes_return() def pop(self): return self.data.pop() @@ -31,7 +31,7 @@ standard Python ``list``, ``set`` and ``dict`` interfaces. These could be specified in terms of generic decorator recipes, but are instead hand-tooled for increased efficiency. The targeted decorators occasionally implement adapter-like behavior, such as mapping bulk-set methods (``extend``, ``update``, -``__setslice``, etc.) into the series of atomic mutation events that the ORM +``__setslice__``, etc.) into the series of atomic mutation events that the ORM requires. The targeted decorators are used internally for automatic instrumentation of @@ -95,26 +95,20 @@ The owning object and InstrumentedCollectionAttribute are also reachable through the adapter, allowing for some very sophisticated behavior. """ -import copy, inspect, sys, weakref +import copy +import inspect +import sets +import sys +import weakref from sqlalchemy import exceptions, schema, util as sautil -from sqlalchemy.orm import mapper - -try: - from threading import Lock -except: - from dummy_threading import Lock -try: - from operator import attrgetter -except: - def attrgetter(attribute): - return lambda value: getattr(value, attribute) +from sqlalchemy.util import attrgetter, Set __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] - + def column_mapped_collection(mapping_spec): """A dictionary-based collection type with column-based keying. @@ -127,10 +121,12 @@ def column_mapped_collection(mapping_spec): after a session flush. 
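Both mapped-collection factories above reduce to the same move: build a keyfunc, then key a dict-like collection by it. A standalone sketch of that idea without any ORM instrumentation (Note and its 'keyword' attribute are invented for the example):

    from operator import attrgetter

    class KeyedDict(dict):
        def __init__(self, keyfunc):
            dict.__init__(self)
            self.keyfunc = keyfunc

        def add(self, value):
            self[self.keyfunc(value)] = value

    class Note(object):
        def __init__(self, keyword, text):
            self.keyword, self.text = keyword, text

    notes = KeyedDict(attrgetter('keyword'))
    notes.add(Note('todo', 'write docs'))
    notes.add(Note('done', 'merge branch'))
    assert sorted(notes) == ['done', 'todo']
    assert notes['todo'].text == 'write docs'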
""" + from sqlalchemy.orm import object_mapper + if isinstance(mapping_spec, schema.Column): def keyfunc(value): - m = mapper.object_mapper(value) - return m.get_attr_by_column(value, mapping_spec) + m = object_mapper(value) + return m._get_attr_by_column(value, mapping_spec) else: cols = [] for c in mapping_spec: @@ -140,8 +136,8 @@ def column_mapped_collection(mapping_spec): cols.append(c) mapping_spec = tuple(cols) def keyfunc(value): - m = mapper.object_mapper(value) - return tuple([m.get_attr_by_column(value, c) for c in mapping_spec]) + m = object_mapper(value) + return tuple([m._get_attr_by_column(value, c) for c in mapping_spec]) return lambda: MappedCollection(keyfunc) def attribute_mapped_collection(attr_name): @@ -201,7 +197,7 @@ class collection(object): # Bundled as a class solely for ease of use: packaging, doc strings, # importability. - + def appender(cls, fn): """Tag the method as the collection appender. @@ -236,7 +232,7 @@ class collection(object): database contains rows that violate your collection semantics, you will need to get creative to fix the problem, as access via the collection will not work. - + If the appender method is internally instrumented, you must also receive the keyword argument '_sa_initiator' and ensure its promulgation to collection events. @@ -268,7 +264,7 @@ class collection(object): receive the keyword argument '_sa_initiator' and ensure its promulgation to collection events. """ - + setattr(fn, '_sa_instrument_role', 'remover') return fn remover = classmethod(remover) @@ -302,7 +298,7 @@ class collection(object): @collection.internally_instrumented def extend(self, items): ... """ - + setattr(fn, '_sa_instrumented', True) return fn internally_instrumented = classmethod(internally_instrumented) @@ -316,11 +312,44 @@ class collection(object): the instance. A single argument is passed: the collection adapter that has been linked, or None if unlinking. """ - + setattr(fn, '_sa_instrument_role', 'on_link') return fn on_link = classmethod(on_link) + def converter(cls, fn): + """Tag the method as the collection converter. + + This optional method will be called when a collection is being + replaced entirely, as in:: + + myobj.acollection = [newvalue1, newvalue2] + + The converter method will receive the object being assigned and should + return an iterable of values suitable for use by the ``appender`` + method. A converter must not assign values or mutate the collection, + it's sole job is to adapt the value the user provides into an iterable + of values for the ORM's use. + + The default converter implementation will use duck-typing to do the + conversion. A dict-like collection will be convert into an iterable + of dictionary values, and other types will simply be iterated. + + @collection.converter + def convert(self, other): ... + + If the duck-typing of the object does not match the type of this + collection, a TypeError is raised. + + Supply an implementation of this method if you want to expand the + range of possible types that can be assigned in bulk or perform + validation on the values about to be assigned. + """ + + setattr(fn, '_sa_instrument_role', 'converter') + return fn + converter = classmethod(converter) + def adds(cls, arg): """Mark the method as adding an entity to the collection. @@ -348,13 +377,13 @@ class collection(object): the method. The decorator argument indicates which method argument holds the SQLAlchemy-relevant value to be added, and return value, if any will be considered the value to remove. 
- + Arguments can be specified positionally (i.e. integer) or by name:: @collection.replaces(2) def __setitem__(self, index, item): ... """ - + def decorator(fn): setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) setattr(fn, '_sa_instrument_after', 'fire_remove_event') @@ -382,7 +411,7 @@ class collection(object): return fn return decorator removes = classmethod(removes) - + def removes_return(cls): """Mark the method as removing an entity in the collection. @@ -411,6 +440,21 @@ def collection_adapter(collection): return getattr(collection, '_sa_adapter', None) +def collection_iter(collection): + """Iterate over an object supporting the @iterator or __iter__ protocols. + + If the collection is an ORM collection, it need not be attached to an + object to be iterable. + """ + + try: + return getattr(collection, '_sa_iterator', + getattr(collection, '__iter__'))() + except AttributeError: + raise TypeError("'%s' object is not iterable" % + type(collection).__name__) + + class CollectionAdapter(object): """Bridges between the ORM and arbitrary Python collections. @@ -422,14 +466,12 @@ class CollectionAdapter(object): entity collections. """ - def __init__(self, attr, owner, data): + def __init__(self, attr, owner_state, data): self.attr = attr - self._owner = weakref.ref(owner) self._data = weakref.ref(data) + self.owner_state = owner_state self.link_to_self(data) - owner = property(lambda s: s._owner(), - doc="The object that owns the entity collection.") data = property(lambda s: s._data(), doc="The entity collection being adapted.") @@ -447,6 +489,46 @@ class CollectionAdapter(object): if hasattr(data, '_sa_on_link'): getattr(data, '_sa_on_link')(None) + def adapt_like_to_iterable(self, obj): + """Converts collection-compatible objects to an iterable of values. + + Can be passed any type of object, and if the underlying collection + determines that it can be adapted into a stream of values it can + use, returns an iterable of values suitable for append()ing. + + This method may raise TypeError or any other suitable exception + if adaptation fails. + + If a converter implementation is not supplied on the collection, + a default duck-typing-based implementation is used. + """ + + converter = getattr(self._data(), '_sa_converter', None) + if converter is not None: + return converter(obj) + + setting_type = sautil.duck_type_collection(obj) + receiving_type = sautil.duck_type_collection(self._data()) + + if obj is None or setting_type != receiving_type: + given = obj is None and 'None' or obj.__class__.__name__ + if receiving_type is None: + wanted = self._data().__class__.__name__ + else: + wanted = receiving_type.__name__ + + raise TypeError( + "Incompatible collection type: %s is not %s-like" % ( + given, wanted)) + + # If the object is an adapted collection, return the (iterable) adapter. + if getattr(obj, '_sa_adapter', None) is not None: + return getattr(obj, '_sa_adapter') + elif setting_type == dict: + return getattr(obj, 'itervalues', getattr(obj, 'values'))() + else: + return iter(obj) + def append_with_event(self, item, initiator=None): """Add an entity to the collection, firing mutation events.""" @@ -499,12 +581,12 @@ class CollectionAdapter(object): mutation, and should be left as None unless you are passing along an initiator value from a chained operation. 
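The default conversion in adapt_like_to_iterable() above is duck-typing: a dict-like operand contributes its values, other iterables are iterated directly, and a type mismatch with the receiving collection is rejected. A standalone sketch of that policy (the real code consults util.duck_type_collection rather than checking for keys()):

    def duck_convert(receiving, obj):
        receiving_is_dict = hasattr(receiving, 'keys')
        incoming_is_dict = hasattr(obj, 'keys')
        if receiving_is_dict != incoming_is_dict:
            raise TypeError("Incompatible collection type: %s" % type(obj).__name__)
        if incoming_is_dict:
            return list(obj.values())
        return list(iter(obj))

    assert duck_convert([], ('a', 'b')) == ['a', 'b']
    assert sorted(duck_convert({}, {'x': 1, 'y': 2})) == [1, 2]
    try:
        duck_convert([], {'x': 1})
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError for a dict assigned to a list-like")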
""" - + if initiator is not False and item is not None: - self.attr.fire_append_event(self._owner(), item, initiator) + self.attr.fire_append_event(self.owner_state, item, initiator) def fire_remove_event(self, item, initiator=None): - """Notify that a entity has entered the collection. + """Notify that a entity has been removed from the collection. Initiator is the InstrumentedAttribute that initiated the membership mutation, and should be left as None unless you are passing along @@ -512,20 +594,66 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - self.attr.fire_remove_event(self._owner(), item, initiator) - + self.attr.fire_remove_event(self.owner_state, item, initiator) + + def fire_pre_remove_event(self, initiator=None): + """Notify that an entity is about to be removed from the collection. + + Only called if the entity cannot be removed after calling + fire_remove_event(). + """ + + self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator) + def __getstate__(self): return { 'key': self.attr.key, - 'owner': self.owner, + 'owner_state': self.owner_state, 'data': self.data } def __setstate__(self, d): - self.attr = getattr(d['owner'].__class__, d['key']) - self._owner = weakref.ref(d['owner']) + self.attr = getattr(d['owner_state'].obj().__class__, d['key']).impl + self.owner_state = d['owner_state'] self._data = weakref.ref(d['data']) -__instrumentation_mutex = Lock() +def bulk_replace(values, existing_adapter, new_adapter): + """Load a new collection, firing events based on prior like membership. + + Appends instances in ``values`` onto the ``new_adapter``. Events will be + fired for any instance not present in the ``existing_adapter``. Any + instances in ``existing_adapter`` not present in ``values`` will have + remove events fired upon them. + + values + An iterable of collection member instances + + existing_adapter + A CollectionAdapter of instances to be replaced + + new_adapter + An empty CollectionAdapter to load with ``values`` + + + """ + if not isinstance(values, list): + values = list(values) + + idset = sautil.IdentitySet + constants = idset(existing_adapter or ()).intersection(values or ()) + additions = idset(values or ()).difference(constants) + removals = idset(existing_adapter or ()).difference(constants) + + for member in values or (): + if member in additions: + new_adapter.append_with_event(member) + elif member in constants: + new_adapter.append_without_event(member) + + if existing_adapter: + for member in removals: + existing_adapter.remove_with_event(member) + +__instrumentation_mutex = sautil.threading.Lock() def _prepare_instrumentation(factory): """Prepare a callable for future use as a collection class factory. @@ -593,7 +721,7 @@ def _instrument_class(cls): # FIXME: more formally document this as a decoratorless/Python 2.3 # option for specifying instrumentation. (likely doc'd here in code only, # not in online docs.) - # + # # __instrumentation__ = { # 'rolename': 'methodname', # ... # 'methods': { @@ -612,7 +740,7 @@ def _instrument_class(cls): raise exceptions.ArgumentError( "Can not instrument a built-in type. 
Use a " "subclass, even a trivial one.") - + collection_type = sautil.duck_type_collection(cls) if collection_type in __interfaces: roles = __interfaces[collection_type].copy() @@ -626,14 +754,15 @@ def _instrument_class(cls): methods = roles.pop('methods', {}) for name in dir(cls): - method = getattr(cls, name) + method = getattr(cls, name, None) if not callable(method): continue # note role declarations if hasattr(method, '_sa_instrument_role'): role = method._sa_instrument_role - assert role in ('appender', 'remover', 'iterator', 'on_link') + assert role in ('appender', 'remover', 'iterator', + 'on_link', 'converter') roles[role] = name # transfer instrumentation requests from decorated function @@ -686,7 +815,7 @@ def _instrument_class(cls): for method, (before, argument, after) in methods.items(): setattr(cls, method, _instrument_membership_mutator(getattr(cls, method), - before, argument, after)) + before, argument, after)) # intern the role map for role, method in roles.items(): setattr(cls, '_sa_%s' % role, getattr(cls, method)) @@ -696,58 +825,52 @@ def _instrument_class(cls): def _instrument_membership_mutator(method, before, argument, after): """Route method args and/or return value through the collection adapter.""" - if type(argument) is int: - def wrapper(*args, **kw): - if before and len(args) < argument: - raise exceptions.ArgumentError( - 'Missing argument %i' % argument) - initiator = kw.pop('_sa_initiator', None) - if initiator is False: - executor = None + # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' + if before: + fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0])) + if type(argument) is int: + pos_arg = argument + named_arg = len(fn_args) > argument and fn_args[argument] or None + else: + if argument in fn_args: + pos_arg = fn_args.index(argument) else: - executor = getattr(args[0], '_sa_adapter', None) - - if before and executor: - getattr(executor, before)(args[argument], initiator) + pos_arg = None + named_arg = argument + del fn_args - if not after or not executor: - return method(*args, **kw) + def wrapper(*args, **kw): + if before: + if pos_arg is None: + if named_arg not in kw: + raise exceptions.ArgumentError( + "Missing argument %s" % argument) + value = kw[named_arg] else: - res = method(*args, **kw) - if res is not None: - getattr(executor, after)(res, initiator) - return res - else: - def wrapper(*args, **kw): - if before: - vals = inspect.getargvalues(inspect.currentframe()) - if argument in kw: - value = kw[argument] + if len(args) > pos_arg: + value = args[pos_arg] + elif named_arg in kw: + value = kw[named_arg] else: - positional = inspect.getargspec(method)[0] - pos = positional.index(argument) - if pos == -1: - raise exceptions.ArgumentError('Missing argument %s' % - argument) - else: - value = args[pos] + raise exceptions.ArgumentError( + "Missing argument %s" % argument) - initiator = kw.pop('_sa_initiator', None) - if initiator is False: - executor = None - else: - executor = getattr(args[0], '_sa_adapter', None) + initiator = kw.pop('_sa_initiator', None) + if initiator is False: + executor = None + else: + executor = getattr(args[0], '_sa_adapter', None) - if before and executor: - getattr(executor, before)(value, initiator) + if before and executor: + getattr(executor, before)(value, initiator) - if not after or not executor: - return method(*args, **kw) - else: - res = method(*args, **kw) - if res is not None: - getattr(executor, after)(res, initiator) - return res + if not after or not executor: + 
return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res try: wrapper._sa_instrumented = True wrapper.__name__ = method.__name__ @@ -763,7 +886,7 @@ def __set(collection, item, _sa_initiator=None): executor = getattr(collection, '_sa_adapter', None) if executor: getattr(executor, 'fire_append_event')(item, _sa_initiator) - + def __del(collection, item, _sa_initiator=None): """Run del events, may eventually be inlined into decorators.""" @@ -771,11 +894,18 @@ def __del(collection, item, _sa_initiator=None): executor = getattr(collection, '_sa_adapter', None) if executor: getattr(executor, 'fire_remove_event')(item, _sa_initiator) - + +def __before_delete(collection, _sa_initiator=None): + """Special method to run 'commit existing value' methods""" + + executor = getattr(collection, '_sa_adapter', None) + if executor: + getattr(executor, 'fire_pre_remove_event')(_sa_initiator) + def _list_decorators(): """Hand-turned instrumentation wrappers that can decorate any list-like class.""" - + def _tidy(fn): setattr(fn, '_sa_instrumented', True) fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__') @@ -787,7 +917,7 @@ def _list_decorators(): if _sa_initiator is not False and item is not None: executor = getattr(self, '_sa_adapter', None) if executor: - executor.attr.fire_append_event(executor._owner(), + executor.attr.fire_append_event(executor.owner_state, item, _sa_initiator) fn(self, item) _tidy(append) @@ -795,6 +925,8 @@ def _list_decorators(): def remove(fn): def remove(self, value, _sa_initiator=None): + __before_delete(self, _sa_initiator) + # testlib.pragma exempt:__eq__ fn(self, value) __del(self, value, _sa_initiator) _tidy(remove) @@ -868,7 +1000,7 @@ def _list_decorators(): fn(self, start, end, values) _tidy(__setslice__) return __setslice__ - + def __delslice__(fn): def __delslice__(self, start, end): for value in self[start:end]: @@ -883,15 +1015,31 @@ def _list_decorators(): self.append(value) _tidy(extend) return extend - + + def __iadd__(fn): + def __iadd__(self, iterable): + # list.__iadd__ takes any iterable and seems to let TypeError raise + # as-is instead of returning NotImplemented + for value in iterable: + self.append(value) + return self + _tidy(__iadd__) + return __iadd__ + def pop(fn): def pop(self, index=-1): + __before_delete(self) item = fn(self, index) __del(self, item) return item _tidy(pop) return pop + # __imul__ : not wrapping this. all members of the collection are already + # present, so no need to fire appends... wrapping it with an explicit + # decorator is still possible, so events on *= can be had if they're + # desired. hard to imagine a use case for __imul__, though. 
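Every wrapper in these decorator factories follows one shape: find the adapter, fire a "before" event with the affected value, run the real mutator, and fire an "after" event on any result. A minimal standalone sketch of just the "before" half, for a single method, with a plain list of event tuples standing in for the adapter:

    def instrument_append(cls, events):
        plain_append = cls.append

        def append(self, item):
            events.append(('append_event', item))   # stand-in for fire_append_event
            plain_append(self, item)

        cls.append = append

    class WatchedList(list):
        pass

    log = []
    instrument_append(WatchedList, log)
    wl = WatchedList()
    wl.append('a')
    wl.extend(['b'])          # extend is not instrumented in this sketch
    assert list(wl) == ['a', 'b']
    assert log == [('append_event', 'a')]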
+ l = locals().copy() l.pop('_tidy') return l @@ -904,7 +1052,7 @@ def _dict_decorators(): setattr(fn, '_sa_instrumented', True) fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__') - Unspecified=object() + Unspecified=sautil.symbol('Unspecified') def __setitem__(fn): def __setitem__(self, key, value, _sa_initiator=None): @@ -944,6 +1092,7 @@ def _dict_decorators(): def popitem(fn): def popitem(self): + __before_delete(self) item = fn(self) __del(self, item[1]) return item @@ -964,7 +1113,7 @@ def _dict_decorators(): def update(fn): def update(self, other): for key in other.keys(): - if not self.has_key(key) or self[key] is not other[key]: + if key not in self or self[key] is not other[key]: self[key] = other[key] _tidy(update) return update @@ -991,41 +1140,72 @@ def _dict_decorators(): l.pop('Unspecified') return l + +try: + _set_binop_bases = (set, frozenset, sets.BaseSet) +except NameError: + _set_binop_bases = (sets.BaseSet,) + +def _set_binops_check_strict(self, obj): + """Allow only set, frozenset and self.__class__-derived objects in binops.""" + return isinstance(obj, _set_binop_bases + (self.__class__,)) + +def _set_binops_check_loose(self, obj): + """Allow anything set-like to participate in set binops.""" + return (isinstance(obj, _set_binop_bases + (self.__class__,)) or + sautil.duck_type_collection(obj) == sautil.Set) + + def _set_decorators(): """Hand-turned instrumentation wrappers that can decorate any set-like sequence class.""" def _tidy(fn): setattr(fn, '_sa_instrumented', True) - fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__') + fn.__doc__ = getattr(getattr(Set, fn.__name__), '__doc__') - Unspecified=object() + Unspecified=sautil.symbol('Unspecified') def add(fn): def add(self, value, _sa_initiator=None): - __set(self, value, _sa_initiator) + if value not in self: + __set(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ fn(self, value) _tidy(add) return add - def discard(fn): - def discard(self, value, _sa_initiator=None): - if value in self: - __del(self, value, _sa_initiator) - fn(self, value) - _tidy(discard) - return discard + if sys.version_info < (2, 4): + def discard(fn): + def discard(self, value, _sa_initiator=None): + if value in self: + self.remove(value, _sa_initiator) + _tidy(discard) + return discard + else: + def discard(fn): + def discard(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + _tidy(discard) + return discard def remove(fn): def remove(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ if value in self: __del(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ fn(self, value) _tidy(remove) return remove def pop(fn): def pop(self): + __before_delete(self) item = fn(self) __del(self, item) return item @@ -1042,11 +1222,19 @@ def _set_decorators(): def update(fn): def update(self, value): for item in value: - if item not in self: - self.add(item) + self.add(item) _tidy(update) return update - __ior__ = update + + def __ior__(fn): + def __ior__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.add(item) + return self + _tidy(__ior__) + return __ior__ def difference_update(fn): def difference_update(self, value): @@ -1054,11 +1242,20 @@ def _set_decorators(): self.discard(item) _tidy(difference_update) return difference_update - __isub__ = difference_update + + def __isub__(fn): + def 
__isub__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.discard(item) + return self + _tidy(__isub__) + return __isub__ def intersection_update(fn): def intersection_update(self, other): - want, have = self.intersection(other), sautil.Set(self) + want, have = self.intersection(other), Set(self) remove, add = have - want, want - have for item in remove: @@ -1067,11 +1264,25 @@ def _set_decorators(): self.add(item) _tidy(intersection_update) return intersection_update - __iand__ = intersection_update + + def __iand__(fn): + def __iand__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), Set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + _tidy(__iand__) + return __iand__ def symmetric_difference_update(fn): def symmetric_difference_update(self, other): - want, have = self.symmetric_difference(other), sautil.Set(self) + want, have = self.symmetric_difference(other), Set(self) remove, add = have - want, want - have for item in remove: @@ -1080,7 +1291,21 @@ def _set_decorators(): self.add(item) _tidy(symmetric_difference_update) return symmetric_difference_update - __ixor__ = symmetric_difference_update + + def __ixor__(fn): + def __ixor__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), Set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + _tidy(__ixor__) + return __ixor__ l = locals().copy() l.pop('_tidy') @@ -1096,7 +1321,7 @@ class InstrumentedList(list): 'remover': 'remove', 'iterator': '__iter__', } -class InstrumentedSet(sautil.Set): +class InstrumentedSet(Set): """An instrumented version of the built-in set (or Set).""" __instrumentation__ = { @@ -1104,7 +1329,7 @@ class InstrumentedSet(sautil.Set): 'remover': 'remove', 'iterator': '__iter__', } -class InstrumentedDict(dict): +class InstrumentedDict(dict): """An instrumented version of the built-in dict.""" __instrumentation__ = { @@ -1112,7 +1337,7 @@ class InstrumentedDict(dict): __canned_instrumentation = { list: InstrumentedList, - sautil.Set: InstrumentedSet, + Set: InstrumentedSet, dict: InstrumentedDict, } @@ -1121,10 +1346,10 @@ __interfaces = { 'remover': 'remove', 'iterator': '__iter__', '_decorators': _list_decorators(), }, - sautil.Set: { 'appender': 'add', - 'remover': 'remove', - 'iterator': '__iter__', - '_decorators': _set_decorators(), }, + Set: { 'appender': 'add', + 'remover': 'remove', + 'iterator': '__iter__', + '_decorators': _set_decorators(), }, # decorators are required for dicts and object collections. dict: { 'iterator': 'itervalues', '_decorators': _dict_decorators(), }, @@ -1141,7 +1366,7 @@ class MappedCollection(dict): callable that takes an object and returns an object for use as a dictionary key. """ - + def __init__(self, keyfunc): """Create a new collection with keying provided by keyfunc. 
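The in-place set operators above deliberately refuse duck-typed operands: anything that is not a set, frozenset or an instance of the collection class makes the method return NotImplemented, which the caller sees as a TypeError, while update() and friends stay permissive. A standalone sketch of that policy:

    class StrictSet(set):
        def _acceptable(self, other):
            return isinstance(other, (set, frozenset, self.__class__))

        def __ior__(self, other):
            if not self._acceptable(other):
                return NotImplemented
            for item in other:
                self.add(item)
            return self

    s = StrictSet(['a'])
    s |= set(['b'])                # plain set operand: accepted
    assert s == set(['a', 'b'])

    try:
        s |= ['c']                 # list operand: rejected
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError for a non-set operand")

    s.update(['c'])                # update() remains permissive, as above
    assert 'c' in s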
@@ -1164,12 +1389,13 @@ class MappedCollection(dict): self.__setitem__(key, value, _sa_initiator) set = collection.internally_instrumented(set) set = collection.appender(set) - + def remove(self, value, _sa_initiator=None): """Remove an item from the collection by value, consulting this instance's keyfunc for the key.""" - + key = self.keyfunc(value) # Let self[key] raise if key is not in this collection + # testlib.pragma exempt:__ne__ if self[key] != value: raise exceptions.InvalidRequestError( "Can not remove '%s': collection holds '%s' for key '%s'. " @@ -1180,3 +1406,26 @@ class MappedCollection(dict): self.__delitem__(key, _sa_initiator) remove = collection.internally_instrumented(remove) remove = collection.remover(remove) + + def _convert(self, dictlike): + """Validate and convert a dict-like object into values for set()ing. + + This is called behind the scenes when a MappedCollection is replaced + entirely by another collection, as in:: + + myobj.mappedcollection = {'a':obj1, 'b': obj2} # ... + + Raises a TypeError if the key in any (key, value) pair in the dictlike + object does not match the key that this collection's keyfunc would + have assigned for that value. + """ + + for incoming_key, value in sautil.dictlike_iteritems(dictlike): + new_key = self.keyfunc(value) + if incoming_key != new_key: + raise TypeError( + "Found incompatible key %r for value %r; this collection's " + "keying function requires a key of %r for this value." % ( + incoming_key, value, new_key)) + yield value + _convert = collection.converter(_convert) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index c06db69631..c667460a71 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -1,5 +1,5 @@ # orm/dependency.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -11,9 +11,9 @@ """ from sqlalchemy.orm import sync -from sqlalchemy.orm.sync import ONETOMANY,MANYTOONE,MANYTOMANY from sqlalchemy import sql, util, exceptions -from sqlalchemy.orm import session as sessionlib +from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY + def create_dependency_processor(prop): types = { @@ -27,22 +27,24 @@ def create_dependency_processor(prop): return types[prop.direction](prop) class DependencyProcessor(object): + no_dependencies = False + def __init__(self, prop): self.prop = prop self.cascade = prop.cascade self.mapper = prop.mapper self.parent = prop.parent - self.association = prop.association self.secondary = prop.secondary self.direction = prop.direction self.is_backref = prop.is_backref self.post_update = prop.post_update self.foreign_keys = prop.foreign_keys self.passive_deletes = prop.passive_deletes + self.passive_updates = prop.passive_updates self.enable_typechecks = prop.enable_typechecks self.key = prop.key - - self._compile_synchronizers() + if not self.prop.synchronize_pairs: + raise exceptions.ArgumentError("Can't build a DependencyProcessor for relation %s. 
No target attributes to populate between parent and child are present" % self.prop) def _get_instrumented_attribute(self): """Return the ``InstrumentedAttribute`` handled by this @@ -51,6 +53,13 @@ class DependencyProcessor(object): return getattr(self.parent.class_, self.key) + def hasparent(self, state): + """return True if the given object instance has a parent, + according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``.""" + + # TODO: use correct API for this + return self._get_instrumented_attribute().impl.hasparent(state) + def register_dependencies(self, uowcommit): """Tell a ``UOWTransaction`` what mappers are dependent on which, with regards to the two or three mappers handled by @@ -65,21 +74,18 @@ class DependencyProcessor(object): raise NotImplementedError() - def whose_dependent_on_who(self, obj1, obj2): + def whose_dependent_on_who(self, state1, state2): """Given an object pair assuming `obj2` is a child of `obj1`, return a tuple with the dependent object second, or None if - they are equal. - - Used by objectstore's object-level topological sort (i.e. cyclical - table dependency). + there is no dependency. """ - if obj1 is obj2: + if state1 is state2: return None elif self.direction == ONETOMANY: - return (obj1, obj2) + return (state1, state2) else: - return (obj2, obj1) + return (state2, state1) def process_dependencies(self, task, deplist, uowcommit, delete = False): """This method is called during a flush operation to @@ -101,13 +107,13 @@ class DependencyProcessor(object): raise NotImplementedError() - def _verify_canload(self, child): + def _verify_canload(self, state): if not self.enable_typechecks: return - if child is not None and not self.mapper.canload(child): - raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (child.__class__, self.prop, self.mapper)) - - def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): + if state is not None and not self.mapper._canload(state): + raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper)) + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): """Called during a flush to synchronize primary key identifier values between a parent/child object, as well as to an associationrow in the case of many-to-many. @@ -115,32 +121,8 @@ class DependencyProcessor(object): raise NotImplementedError() - def _compile_synchronizers(self): - """Assemble a list of *synchronization rules*, which are - instructions on how to populate the objects on each side of a - relationship. This is done when a ``DependencyProcessor`` is - first initialized. - - The list of rules is used within commits by the ``_synchronize()`` - method when dependent objects are processed. 
- """ - - self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction) - if self.direction == sync.MANYTOMANY: - self.syncrules.compile(self.prop.primaryjoin, issecondary=False, foreign_keys=self.foreign_keys) - self.syncrules.compile(self.prop.secondaryjoin, issecondary=True, foreign_keys=self.foreign_keys) - else: - self.syncrules.compile(self.prop.primaryjoin, foreign_keys=self.foreign_keys) - - def get_object_dependencies(self, obj, uowcommit, passive = True): - """Return the list of objects that are dependent on the given - object, as according to the relationship this dependency - processor represents. - """ - return sessionlib.attribute_manager.get_history(obj, self.key, passive = passive) - - def _conditional_post_update(self, obj, uowcommit, related): + def _conditional_post_update(self, state, uowcommit, related): """Execute a post_update call. For relations that contain the post_update flag, an additional @@ -154,12 +136,15 @@ class DependencyProcessor(object): given related object list contains ``INSERT``s or ``DELETE``s. """ - if obj is not None and self.post_update: + if state is not None and self.post_update: for x in related: if x is not None: - uowcommit.register_object(obj, postupdate=True, post_update_cols=self.syncrules.dest_columns()) + uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs]) break + def _pks_changed(self, uowcommit, state): + raise NotImplementedError() + def __str__(self): return "%s(%s)" % (self.__class__.__name__, str(self.prop)) @@ -182,28 +167,35 @@ class OneToManyDP(DependencyProcessor): # the child objects have to have their foreign key to the parent set to NULL # this phase can be called safely for any cascade but is unnecessary if delete cascade # is on. 
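Condensed, the delete branch described in the comment above makes one of three choices per parent: leave cleanup entirely to the database (passive_deletes='all'), schedule the children for deletion (delete cascade), or null out their foreign key. A deliberately tiny sketch of that decision with plain values, using none of the real uowcommit API:

    def plan_child_actions(children, delete_cascade=False, passive_deletes=None):
        if passive_deletes == 'all':
            return []                              # leave FK/row cleanup to the database
        if delete_cascade:
            return [(child, 'DELETE') for child in children]
        return [(child, 'SET parent_id = NULL') for child in children]

    children = ['addr1', 'addr2']
    assert plan_child_actions(children) == [
        ('addr1', 'SET parent_id = NULL'), ('addr2', 'SET parent_id = NULL')]
    assert plan_child_actions(children, delete_cascade=True) == [
        ('addr1', 'DELETE'), ('addr2', 'DELETE')]
    assert plan_child_actions(children, passive_deletes='all') == []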
- if not self.cascade.delete or self.post_update: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items(): - if child is not None and childlist.hasparent(child) is False: - self._synchronize(obj, child, None, True, uowcommit) - self._conditional_post_update(child, uowcommit, [obj]) - for child in childlist.unchanged_items(): - if child is not None: - self._synchronize(obj, child, None, True, uowcommit) - self._conditional_post_update(child, uowcommit, [obj]) + if self.post_update or not self.passive_deletes=='all': + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + if unchanged or deleted: + for child in deleted: + if child is not None and self.hasparent(child) is False: + self._synchronize(state, child, None, True, uowcommit) + self._conditional_post_update(child, uowcommit, [state]) + if self.post_update or not self.cascade.delete: + for child in unchanged: + if child is not None: + self._synchronize(state, child, None, True, uowcommit) + self._conditional_post_update(child, uowcommit, [state]) else: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=True) - if childlist is not None: - for child in childlist.added_items(): - self._synchronize(obj, child, None, False, uowcommit) - self._conditional_post_update(child, uowcommit, [obj]) - for child in childlist.deleted_items(): - if not self.cascade.delete_orphan and not self._get_instrumented_attribute().hasparent(child): - self._synchronize(obj, child, None, True, uowcommit) + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True) + if added or deleted: + for child in added: + self._synchronize(state, child, None, False, uowcommit) + if child is not None: + self._conditional_post_update(child, uowcommit, [state]) + for child in deleted: + if not self.cascade.delete_orphan and not self.hasparent(child): + self._synchronize(state, child, None, True, uowcommit) + + if self._pks_changed(uowcommit, state): + if unchanged: + for child in unchanged: + self._synchronize(state, child, None, False, uowcommit) def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction) @@ -211,40 +203,104 @@ class OneToManyDP(DependencyProcessor): if delete: # head object is being deleted, and we manage its list of child objects # the child objects have to have their foreign key to the parent set to NULL - if not self.post_update and not self.cascade.delete: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items(): - if child is not None and childlist.hasparent(child) is False: - uowcommit.register_object(child) - for child in childlist.unchanged_items(): - if child is not None: - uowcommit.register_object(child) + if not self.post_update: + should_null_fks = not self.cascade.delete and not self.passive_deletes=='all' + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + if unchanged or deleted: + for child in deleted: + if child is not None and self.hasparent(child) is False: + if 
self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=True) + else: + uowcommit.register_object(child) + if should_null_fks: + for child in unchanged: + if child is not None: + uowcommit.register_object(child) else: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=True) - if childlist is not None: - for child in childlist.added_items(): + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True) + if added or deleted: + for child in added: if child is not None: uowcommit.register_object(child) - for child in childlist.deleted_items(): + for child in deleted: if not self.cascade.delete_orphan: uowcommit.register_object(child, isdelete=False) - elif childlist.hasparent(child) is False: + elif self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c, isdelete=True) + for c, m in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c._state, isdelete=True) + if not self.passive_updates and self._pks_changed(uowcommit, state): + if not unchanged: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=False) + if unchanged: + for child in unchanged: + uowcommit.register_object(child) - def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): - source = obj + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + source = state dest = child if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): return self._verify_canload(child) - self.syncrules.execute(source, dest, obj, child, clearkeys) + if clearkeys: + sync.clear(dest, self.mapper, self.prop.synchronize_pairs) + else: + sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs) + + def _pks_changed(self, uowcommit, state): + return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs) + +class DetectKeySwitch(DependencyProcessor): + """a special DP that works for many-to-one relations, fires off for + child items who have changed their referenced key.""" + + no_dependencies = True + + def register_dependencies(self, uowcommit): + uowcommit.register_processor(self.parent, self, self.mapper) + + def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): + # for non-passive updates, register in the preprocess stage + # so that mapper save_obj() gets a hold of changes + if not delete and not self.passive_updates: + self._process_key_switches(deplist, uowcommit) + + def process_dependencies(self, task, deplist, uowcommit, delete=False): + # for passive updates, register objects in the process stage + # so that we avoid ManyToOneDP's registering the object without + # the listonly flag in its own preprocess stage (results in UPDATE) + # statements being emitted + if not delete and self.passive_updates: + self._process_key_switches(deplist, uowcommit) + + def _process_key_switches(self, deplist, uowcommit): + switchers = util.Set([s for s in deplist if self._pks_changed(uowcommit, s)]) + if switchers: + # yes, we're doing a linear search right now through the UOW. only + # takes effect when primary key values have actually changed. + # a possible optimization might be to enhance the "hasparents" capability of + # attributes to actually store all parent references, but this introduces + # more complicated attribute accounting. 
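As an illustrative aside (not part of the patch), a minimal sketch of the user-facing scenario this primary-key-switch handling supports, assuming hypothetical User/Address classes, an in-memory SQLite engine, and the relation(passive_updates=False) flag; with passive updates disabled it is the flush process, not the database, that updates dependent foreign keys when a referenced primary key changes::

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey, create_engine
    from sqlalchemy.orm import mapper, relation, create_session

    metadata = MetaData()
    users = Table('users', metadata,
        Column('username', String(50), primary_key=True))
    addresses = Table('addresses', metadata,
        Column('id', Integer, primary_key=True),
        Column('username', String(50), ForeignKey('users.username')),
        Column('email', String(100)))

    class User(object): pass
    class Address(object): pass

    # passive_updates=False asks the ORM, rather than the database, to update
    # dependent foreign key values when the referenced primary key changes.
    mapper(User, users, properties={
        'addresses': relation(Address, passive_updates=False)})
    mapper(Address, addresses)

    engine = create_engine('sqlite://')
    metadata.create_all(engine)
    session = create_session(bind=engine)

    u = User(); u.username = 'jack'
    a = Address(); a.email = 'jack@example.com'
    u.addresses = [a]
    session.save(u)
    session.flush()

    u.username = 'ed'   # primary key switch on the parent row
    session.flush()     # addresses.username rows are updated to 'ed' as well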
+ for s in [elem for elem in uowcommit.session.identity_map.all_states() + if issubclass(elem.class_, self.parent.class_) and + self.key in elem.dict and + elem.dict[self.key]._state in switchers + ]: + uowcommit.register_object(s, listonly=self.passive_updates) + sync.populate(s.dict[self.key]._state, self.mapper, s, self.parent, self.prop.synchronize_pairs) + #self.syncrules.execute(s.dict[self.key]._state, s, None, None, False) + + def _pks_changed(self, uowcommit, state): + return sync.source_changes(uowcommit, state, self.mapper, self.prop.synchronize_pairs) class ManyToOneDP(DependencyProcessor): + def __init__(self, prop): + DependencyProcessor.__init__(self, prop) + self.mapper._dependency_processors.append(DetectKeySwitch(prop)) + def register_dependencies(self, uowcommit): if self.post_update: if not self.is_backref: @@ -256,58 +312,66 @@ class ManyToOneDP(DependencyProcessor): uowcommit.register_dependency(self.mapper, self.parent) uowcommit.register_processor(self.mapper, self, self.parent) + def process_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) if delete: - if self.post_update and not self.cascade.delete_orphan: + if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes=='all': # post_update means we have to update our row to not reference the child object # before we can DELETE the row - for obj in deplist: - self._synchronize(obj, None, None, True, uowcommit) - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - self._conditional_post_update(obj, uowcommit, childlist.deleted_items() + childlist.unchanged_items() + childlist.added_items()) + for state in deplist: + self._synchronize(state, None, None, True, uowcommit) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + if added or unchanged or deleted: + self._conditional_post_update(state, uowcommit, deleted + unchanged + added) else: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=True) - if childlist is not None: - for child in childlist.added_items(): - self._synchronize(obj, child, None, False, uowcommit) - self._conditional_post_update(obj, uowcommit, childlist.deleted_items() + childlist.unchanged_items() + childlist.added_items()) + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True) + if added or deleted or unchanged: + for child in added: + self._synchronize(state, child, None, False, uowcommit) + self._conditional_post_update(state, uowcommit, deleted + unchanged + added) def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " PRE process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) if self.post_update: return if delete: - if self.cascade.delete: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items() + childlist.unchanged_items(): - if child is not None and childlist.hasparent(child) is False: - uowcommit.register_object(child, isdelete=True) - for c in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c, 
isdelete=True) + if self.cascade.delete or self.cascade.delete_orphan: + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + if self.cascade.delete_orphan: + todelete = added + unchanged + deleted + else: + todelete = added + unchanged + for child in todelete: + if child is None: + continue + uowcommit.register_object(child, isdelete=True) + for c, m in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c._state, isdelete=True) else: - for obj in deplist: - uowcommit.register_object(obj) + for state in deplist: + uowcommit.register_object(state) if self.cascade.delete_orphan: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items(): - if childlist.hasparent(child) is False: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + if deleted: + for child in deleted: + if self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c, isdelete=True) + for c, m in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c._state, isdelete=True) - def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): - source = child - dest = obj - if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + if state is None or (not self.post_update and uowcommit.is_deleted(state)): return - self._verify_canload(child) - self.syncrules.execute(source, dest, obj, child, clearkeys) + + if clearkeys or child is None: + sync.clear(state, self.parent, self.prop.synchronize_pairs) + else: + self._verify_canload(child) + sync.populate(child, self.mapper, state, self.parent, self.prop.synchronize_pairs) class ManyToManyDP(DependencyProcessor): def register_dependencies(self, uowcommit): @@ -327,73 +391,92 @@ class ManyToManyDP(DependencyProcessor): connection = uowcommit.transaction.connection(self.mapper) secondary_delete = [] secondary_insert = [] + secondary_update = [] - if hasattr(self.prop, 'reverse_property'): - reverse_dep = getattr(self.prop.reverse_property, '_dependency_processor', None) + if self.prop._reverse_property: + reverse_dep = getattr(self.prop._reverse_property, '_dependency_processor', None) else: reverse_dep = None - + if delete: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items() + childlist.unchanged_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes): + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + if deleted or unchanged: + for child in deleted + unchanged: + if child is None or (reverse_dep and (reverse_dep, "manytomany", child, state) in uowcommit.attributes): continue associationrow = {} - self._synchronize(obj, child, associationrow, False, uowcommit) + self._synchronize(state, child, associationrow, False, uowcommit) secondary_delete.append(associationrow) - uowcommit.attributes[(self, "manytomany", obj, child)] = True + uowcommit.attributes[(self, "manytomany", state, child)] = True else: - for obj in deplist: - childlist 
= self.get_object_dependencies(obj, uowcommit) - if childlist is None: continue - for child in childlist.added_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes): - continue - associationrow = {} - self._synchronize(obj, child, associationrow, False, uowcommit) - uowcommit.attributes[(self, "manytomany", obj, child)] = True - secondary_insert.append(associationrow) - for child in childlist.deleted_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes): - continue - associationrow = {} - self._synchronize(obj, child, associationrow, False, uowcommit) - uowcommit.attributes[(self, "manytomany", obj, child)] = True - secondary_delete.append(associationrow) - - if len(secondary_delete): + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key) + if added or deleted: + for child in added: + if child is None or (reverse_dep and (reverse_dep, "manytomany", child, state) in uowcommit.attributes): + continue + associationrow = {} + self._synchronize(state, child, associationrow, False, uowcommit) + uowcommit.attributes[(self, "manytomany", state, child)] = True + secondary_insert.append(associationrow) + for child in deleted: + if child is None or (reverse_dep and (reverse_dep, "manytomany", child, state) in uowcommit.attributes): + continue + associationrow = {} + self._synchronize(state, child, associationrow, False, uowcommit) + uowcommit.attributes[(self, "manytomany", state, child)] = True + secondary_delete.append(associationrow) + + if not self.passive_updates and unchanged and self._pks_changed(uowcommit, state): + for child in unchanged: + associationrow = {} + sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs) + sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs) + + #self.syncrules.update(associationrow, state, child, "old_") + secondary_update.append(associationrow) + + if secondary_delete: secondary_delete.sort() # TODO: precompile the delete/insert queries? 
statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow])) result = connection.execute(statement, secondary_delete) - if result.supports_sane_rowcount() and result.rowcount != len(secondary_delete): - raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (result.rowcount, len(secondary_delete))) + if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete): + raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of secondary table rows deleted from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_delete))) + + if secondary_update: + statement = self.secondary.update(sql.and_(*[c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow])) + result = connection.execute(statement, secondary_update) + if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update): + raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of secondary table rows updated from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_update))) - if len(secondary_insert): + if secondary_insert: statement = self.secondary.insert() connection.execute(statement, secondary_insert) def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction) if not delete: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=True) - if childlist is not None: - for child in childlist.deleted_items(): - if self.cascade.delete_orphan and childlist.hasparent(child) is False: + for state in deplist: + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True) + if deleted: + for child in deleted: + if self.cascade.delete_orphan and self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c, isdelete=True) + for c, m in self.mapper.cascade_iterator('delete', child): + uowcommit.register_object(c._state, isdelete=True) - def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): - dest = associationrow - source = None - if dest is None: + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + if associationrow is None: return self._verify_canload(child) - self.syncrules.execute(source, dest, obj, child, clearkeys) + + sync.populate_dict(state, self.parent, associationrow, self.prop.synchronize_pairs) + sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs) + + def _pks_changed(self, uowcommit, state): + return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs) class AssociationDP(OneToManyDP): def __init__(self, *args, **kwargs): @@ -406,7 +489,7 @@ class MapperStub(object): many-to-many join, when performing a ``flush()``. The ``Task`` objects in the objectstore module treat it just like - any other ``Mapper``, but in fact it only serves as a *dependency* + any other ``Mapper``, but in fact it only serves as a dependency placeholder for the many-to-many update task. 
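For context, a minimal sketch (again not taken from the patch) of the many-to-many flush behavior driven by the secondary_insert/secondary_delete lists assembled above, assuming hypothetical Article/Keyword classes and an association table::

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey, create_engine
    from sqlalchemy.orm import mapper, relation, create_session

    metadata = MetaData()
    articles = Table('articles', metadata,
        Column('id', Integer, primary_key=True),
        Column('title', String(100)))
    keywords = Table('keywords', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)))
    article_keywords = Table('article_keywords', metadata,
        Column('article_id', Integer, ForeignKey('articles.id'), primary_key=True),
        Column('keyword_id', Integer, ForeignKey('keywords.id'), primary_key=True))

    class Article(object): pass
    class Keyword(object): pass

    mapper(Article, articles, properties={
        'keywords': relation(Keyword, secondary=article_keywords)})
    mapper(Keyword, keywords)

    engine = create_engine('sqlite://')
    metadata.create_all(engine)
    session = create_session(bind=engine)

    a = Article(); a.title = 'news'
    k = Keyword(); k.name = 'sqlalchemy'
    a.keywords = [k]        # gathered into secondary_insert rows at flush time
    session.save(a)
    session.flush()

    a.keywords.remove(k)    # gathered into secondary_delete rows at flush time
    session.flush()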
""" @@ -414,23 +497,21 @@ class MapperStub(object): def __init__(self, parent, mapper, key): self.mapper = mapper + self.base_mapper = self self.class_ = mapper.class_ self._inheriting_mappers = [] def polymorphic_iterator(self): return iter([self]) - - def register_dependencies(self, uowcommit): + + def _register_dependencies(self, uowcommit): pass - def save_obj(self, *args, **kwargs): + def _save_obj(self, *args, **kwargs): pass - def delete_obj(self, *args, **kwargs): + def _delete_obj(self, *args, **kwargs): pass def primary_mapper(self): return self - - def base_mapper(self): - return self diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py new file mode 100644 index 0000000000..133ad99c89 --- /dev/null +++ b/lib/sqlalchemy/orm/dynamic.py @@ -0,0 +1,190 @@ +"""'dynamic' collection API. returns Query() objects on the 'read' side, alters +a special AttributeHistory on the 'write' side.""" + +from sqlalchemy import exceptions, util, logging +from sqlalchemy.orm import attributes, object_session, util as mapperutil, strategies +from sqlalchemy.orm.query import Query +from sqlalchemy.orm.mapper import has_identity, object_mapper + + +class DynaLoader(strategies.AbstractRelationLoader): + def init_class_attribute(self): + self.is_class_level = True + self._register_attribute(self.parent.class_, impl_class=DynamicAttributeImpl, target_mapper=self.parent_property.mapper, order_by=self.parent_property.order_by) + + def create_row_processor(self, selectcontext, mapper, row): + return (None, None, None) + +DynaLoader.logger = logging.class_logger(DynaLoader) + +class DynamicAttributeImpl(attributes.AttributeImpl): + def __init__(self, class_, key, typecallable, target_mapper, order_by, **kwargs): + super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs) + self.target_mapper = target_mapper + self.order_by=order_by + self.query_class = AppenderQuery + + def get(self, state, passive=False): + if passive: + return self._get_collection_history(state, passive=True).added_items + else: + return self.query_class(self, state) + + def get_collection(self, state, user_data=None, passive=True): + if passive: + return self._get_collection_history(state, passive=passive).added_items + else: + history = self._get_collection_history(state, passive=passive) + return history.added_items + history.unchanged_items + + def fire_append_event(self, state, value, initiator): + state.modified = True + + if self.trackparent and value is not None: + self.sethasparent(value._state, True) + instance = state.obj() + for ext in self.extensions: + ext.append(instance, value, initiator or self) + + def fire_remove_event(self, state, value, initiator): + state.modified = True + + if self.trackparent and value is not None: + self.sethasparent(value._state, False) + + instance = state.obj() + for ext in self.extensions: + ext.remove(instance, value, initiator or self) + + def set(self, state, value, initiator): + if initiator is self: + return + + old_collection = self.get(state).assign(value) + + # TODO: emit events ??? 
+ state.modified = True + + def delete(self, *args, **kwargs): + raise NotImplementedError() + + def get_history(self, state, passive=False): + c = self._get_collection_history(state, passive) + return (c.added_items, c.unchanged_items, c.deleted_items) + + def _get_collection_history(self, state, passive=False): + try: + c = state.dict[self.key] + except KeyError: + state.dict[self.key] = c = CollectionHistory(self, state) + + if not passive: + return CollectionHistory(self, state, apply_to=c) + else: + return c + + def append(self, state, value, initiator, passive=False): + if initiator is not self: + self._get_collection_history(state, passive=True).added_items.append(value) + self.fire_append_event(state, value, initiator) + + def remove(self, state, value, initiator, passive=False): + if initiator is not self: + self._get_collection_history(state, passive=True).deleted_items.append(value) + self.fire_remove_event(state, value, initiator) + + +class AppenderQuery(Query): + def __init__(self, attr, state): + super(AppenderQuery, self).__init__(attr.target_mapper, None) + self.instance = state.obj() + self.attr = attr + + def __session(self): + sess = object_session(self.instance) + if sess is not None and self.autoflush and sess.autoflush and self.instance in sess: + sess.flush() + if not has_identity(self.instance): + return None + else: + return sess + + def session(self): + return self.__session() + session = property(session) + + def __iter__(self): + sess = self.__session() + if sess is None: + return iter(self.attr._get_collection_history(self.instance._state, passive=True).added_items) + else: + return iter(self._clone(sess)) + + def __getitem__(self, index): + sess = self.__session() + if sess is None: + return self.attr._get_collection_history(self.instance._state, passive=True).added_items.__getitem__(index) + else: + return self._clone(sess).__getitem__(index) + + def count(self): + sess = self.__session() + if sess is None: + return len(self.attr._get_collection_history(self.instance._state, passive=True).added_items) + else: + return self._clone(sess).count() + + def _clone(self, sess=None): + # note we're returning an entirely new Query class instance here + # without any assignment capabilities; + # the class of this query is determined by the session. 
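A brief usage sketch of the 'dynamic' collection API defined in this new module, assuming hypothetical User/Address classes and an in-memory SQLite engine; the mapped attribute returns an AppenderQuery on the read side and records appends/removes in a CollectionHistory on the write side::

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey, create_engine
    from sqlalchemy.orm import mapper, dynamic_loader, create_session

    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)))
    addresses = Table('addresses', metadata,
        Column('id', Integer, primary_key=True),
        Column('user_id', Integer, ForeignKey('users.id')),
        Column('email', String(100)))

    class User(object): pass
    class Address(object): pass

    mapper(User, users, properties={'addresses': dynamic_loader(Address)})
    mapper(Address, addresses)

    engine = create_engine('sqlite://')
    metadata.create_all(engine)
    session = create_session(bind=engine)

    u = User(); u.name = 'jack'
    session.save(u)
    session.flush()

    a = Address(); a.email = 'jack@example.com'
    u.addresses.append(a)   # write side: recorded in CollectionHistory
    session.flush()

    # read side: the attribute is an AppenderQuery, i.e. a filterable Query
    total = u.addresses.count()
    matches = u.addresses.filter(addresses.c.email.like('%example%')).all()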
+ instance = self.instance + if sess is None: + sess = object_session(instance) + if sess is None: + try: + sess = object_mapper(instance).get_session() + except exceptions.InvalidRequestError: + raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.instance_str(instance), self.attr.key)) + + q = sess.query(self.attr.target_mapper).with_parent(instance, self.attr.key) + if self.attr.order_by: + q = q.order_by(self.attr.order_by) + return q + + def assign(self, collection): + instance = self.instance + if has_identity(instance): + oldlist = list(self) + else: + oldlist = [] + self.attr._get_collection_history(self.instance._state, passive=True).replace(oldlist, collection) + return oldlist + + def append(self, item): + self.attr.append(self.instance._state, item, None) + + def remove(self, item): + self.attr.remove(self.instance._state, item, None) + + +class CollectionHistory(object): + """Overrides AttributeHistory to receive append/remove events directly.""" + + def __init__(self, attr, state, apply_to=None): + if apply_to: + deleted = util.IdentitySet(apply_to.deleted_items) + added = apply_to.added_items + coll = AppenderQuery(attr, state).autoflush(False) + self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted] + self.added_items = apply_to.added_items + self.deleted_items = apply_to.deleted_items + else: + self.deleted_items = [] + self.added_items = [] + self.unchanged_items = [] + + def replace(self, olditems, newitems): + self.added_items = newitems + self.deleted_items = olditems + diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index aeb8a23fa1..d61ebe9603 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -1,28 +1,61 @@ # interfaces.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Semi-private implementation objects which form the basis +of ORM-mapped attributes, query options and mapper extension. -from sqlalchemy import util, logging, sql +Defines the [sqlalchemy.orm.interfaces#MapperExtension] class, +which can be end-user subclassed to add event-based functionality +to mappers. The remainder of this module is generally private to the +ORM. +""" -# returned by a MapperExtension method to indicate a "do nothing" response -EXT_PASS = object() +from itertools import chain +from sqlalchemy import exceptions, logging, util +from sqlalchemy.sql import expression +class_mapper = None + +__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension', + 'MapperProperty', 'PropComparator', 'StrategizedProperty', + 'build_path', 'MapperOption', + 'ExtensionOption', 'PropertyOption', + 'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ] + +EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE') +EXT_STOP = util.symbol('EXT_STOP') + +ONETOMANY = util.symbol('ONETOMANY') +MANYTOONE = util.symbol('MANYTOONE') +MANYTOMANY = util.symbol('MANYTOMANY') class MapperExtension(object): - """Base implementation for an object that provides overriding - behavior to various Mapper functions. For each method in - MapperExtension, a result of EXT_PASS indicates the functionality - is not overridden. 
+ """Base implementation for customizing Mapper behavior. + + For each method in MapperExtension, returning a result of EXT_CONTINUE + will allow processing to continue to the next MapperExtension in line or + use the default functionality if there are no other extensions. + + Returning EXT_STOP will halt processing of further extensions handling + that method. Some methods such as ``load`` have other return + requirements, see the individual documentation for details. Other than + these exception cases, any return value other than EXT_CONTINUE or + EXT_STOP will be interpreted as equivalent to EXT_STOP. + + EXT_PASS is a synonym for EXT_CONTINUE and is provided for backward + compatibility. """ + def instrument_class(self, mapper, class_): + return EXT_CONTINUE - def init_instance(self, mapper, class_, instance, args, kwargs): - return EXT_PASS + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): + return EXT_CONTINUE - def init_failed(self, mapper, class_, instance, args, kwargs): - return EXT_PASS + def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): + return EXT_CONTINUE def get_session(self): """Retrieve a contextual Session instance with which to @@ -32,61 +65,61 @@ class MapperExtension(object): `__init__` params (i.e. `_sa_session`). """ - return EXT_PASS + return EXT_CONTINUE def load(self, query, *args, **kwargs): """Override the `load` method of the Query object. The return value of this method is used as the result of - ``query.load()`` if the value is anything other than EXT_PASS. + ``query.load()`` if the value is anything other than EXT_CONTINUE. """ - return EXT_PASS + return EXT_CONTINUE def get(self, query, *args, **kwargs): """Override the `get` method of the Query object. The return value of this method is used as the result of - ``query.get()`` if the value is anything other than EXT_PASS. + ``query.get()`` if the value is anything other than EXT_CONTINUE. """ - return EXT_PASS + return EXT_CONTINUE def get_by(self, query, *args, **kwargs): """Override the `get_by` method of the Query object. The return value of this method is used as the result of ``query.get_by()`` if the value is anything other than - EXT_PASS. - + EXT_CONTINUE. + DEPRECATED. """ - return EXT_PASS + return EXT_CONTINUE def select_by(self, query, *args, **kwargs): """Override the `select_by` method of the Query object. The return value of this method is used as the result of ``query.select_by()`` if the value is anything other than - EXT_PASS. - + EXT_CONTINUE. + DEPRECATED. """ - return EXT_PASS + return EXT_CONTINUE def select(self, query, *args, **kwargs): """Override the `select` method of the Query object. The return value of this method is used as the result of ``query.select()`` if the value is anything other than - EXT_PASS. - + EXT_CONTINUE. + DEPRECATED. """ - return EXT_PASS + return EXT_CONTINUE def translate_row(self, mapper, context, row): @@ -97,15 +130,14 @@ class MapperExtension(object): method. """ - return EXT_PASS + return EXT_CONTINUE def create_instance(self, mapper, selectcontext, row, class_): """Receive a row when a new object instance is about to be created from that row. - The method can choose to create the instance itself, or it can - return None to indicate normal object creation should take - place. + The method can choose to create the instance itself, or it can return + EXT_CONTINUE to indicate normal object creation should take place. 
mapper The mapper doing the operation @@ -118,15 +150,18 @@ class MapperExtension(object): class\_ The class we are mapping. + + return value + A new object instance, or EXT_CONTINUE """ - return EXT_PASS + return EXT_CONTINUE def append_result(self, mapper, selectcontext, row, instance, result, **flags): """Receive an object instance before that instance is appended to a result list. - If this method returns EXT_PASS, result appending will proceed + If this method returns EXT_CONTINUE, result appending will proceed normally. if this method returns any other value or None, result appending will not proceed for this instance, giving this extension an opportunity to do the appending itself, if @@ -152,7 +187,7 @@ class MapperExtension(object): `create_row_processor()` method of [sqlalchemy.orm.interfaces#MapperProperty] """ - return EXT_PASS + return EXT_CONTINUE def populate_instance(self, mapper, selectcontext, row, instance, **flags): """Receive a newly-created instance before that instance has @@ -161,14 +196,14 @@ class MapperExtension(object): The normal population of attributes is according to each attribute's corresponding MapperProperty (which includes column-based attributes as well as relationships to other - classes). If this method returns EXT_PASS, instance + classes). If this method returns EXT_CONTINUE, instance population will proceed normally. If any other value or None is returned, instance population will not proceed, giving this extension an opportunity to populate the instance itself, if desired. """ - return EXT_PASS + return EXT_CONTINUE def before_insert(self, mapper, connection, instance): """Receive an object instance before that instance is INSERTed @@ -176,34 +211,70 @@ class MapperExtension(object): This is a good place to set up primary key values and such that aren't handled otherwise. + + Column-based attributes can be modified within this method which will + result in the new value being inserted. However *no* changes to the overall + flush plan can be made; this means any collection modification or + save() operations which occur within this method will not take effect + until the next flush call. + """ - return EXT_PASS + return EXT_CONTINUE + + def after_insert(self, mapper, connection, instance): + """Receive an object instance after that instance is INSERTed.""" + + return EXT_CONTINUE def before_update(self, mapper, connection, instance): - """Receive an object instance before that instance is UPDATEed.""" + """Receive an object instance before that instance is UPDATEed. + + Note that this method is called for all instances that are marked as + "dirty", even those which have no net changes to their column-based + attributes. An object is marked as dirty when any of its column-based + attributes have a "set attribute" operation called or when any of its + collections are modified. If, at update time, no column-based attributes + have any net changes, no UPDATE statement will be issued. This means + that an instance being sent to before_update is *not* a guarantee that + an UPDATE statement will be issued (although you can affect the outcome + here). + + To detect if the column-based attributes on the object have net changes, + and will therefore generate an UPDATE statement, use + ``object_session(instance).is_modified(instance, include_collections=False)``. + + Column-based attributes can be modified within this method which will + result in their being updated. 
However *no* changes to the overall + flush plan can be made; this means any collection modification or + save() operations which occur within this method will not take effect + until the next flush call. + + """ - return EXT_PASS + return EXT_CONTINUE def after_update(self, mapper, connection, instance): """Receive an object instance after that instance is UPDATEed.""" - return EXT_PASS + return EXT_CONTINUE - def after_insert(self, mapper, connection, instance): - """Receive an object instance after that instance is INSERTed.""" + def before_delete(self, mapper, connection, instance): + """Receive an object instance before that instance is DELETEed. - return EXT_PASS + Note that *no* changes to the overall + flush plan can be made here; this means any collection modification, + save() or delete() operations which occur within this method will + not take effect until the next flush call. - def before_delete(self, mapper, connection, instance): - """Receive an object instance before that instance is DELETEed.""" + """ - return EXT_PASS + return EXT_CONTINUE def after_delete(self, mapper, connection, instance): """Receive an object instance after that instance is DELETEed.""" - return EXT_PASS + return EXT_CONTINUE class MapperProperty(object): """Manage the relationship of a ``Mapper`` to a single class @@ -214,7 +285,7 @@ class MapperProperty(object): def setup(self, querycontext, **kwargs): """Called by Query for the purposes of constructing a SQL statement. - + Each MapperProperty associated with the target mapper processes the statement referenced by the query context, adding columns and/or criterion as appropriate. @@ -223,52 +294,58 @@ class MapperProperty(object): pass def create_row_processor(self, selectcontext, mapper, row): - """return a 2-tuple consiting of a row processing function and an instance post-processing function. - + """Return a 3-tuple consiting of two row processing functions and an instance post-processing function. + Input arguments are the query.SelectionContext and the *first* - applicable row of a result set obtained within query.Query.instances(), called - only the first time a particular mapper.populate_instance() is invoked for the - overal result. - - The settings contained within the SelectionContext as well as the columns present - in the row (which will be the same columns present in all rows) are used to determine - the behavior of the returned callables. The callables will then be used to process - all rows and to post-process all instances, respectively. - - callables are of the following form:: - - def execute(instance, row, **flags): - # process incoming instance and given row. - # flags is a dictionary containing at least the following attributes: - # isnew - indicates if the instance was newly created as a result of reading this row + applicable row of a result set obtained within + query.Query.instances(), called only the first time a particular + mapper's populate_instance() method is invoked for the overall result. + + The settings contained within the SelectionContext as well as the + columns present in the row (which will be the same columns present in + all rows) are used to determine the presence and behavior of the + returned callables. The callables will then be used to process all + rows and to post-process all instances, respectively. + + Callables are of the following form:: + + def new_execute(instance, row, **flags): + # process incoming instance and given row. the instance is + # "new" and was just created upon receipt of this row. 
+ # flags is a dictionary containing at least the following + # attributes: + # isnew - indicates if the instance was newly created as a + # result of reading this row # instancekey - identity key of the instance # optional attribute: - # ispostselect - indicates if this row resulted from a 'post' select of additional tables/columns - + # ispostselect - indicates if this row resulted from a + # 'post' select of additional tables/columns + + def existing_execute(instance, row, **flags): + # process incoming instance and given row. the instance is + # "existing" and was created based on a previous row. + def post_execute(instance, **flags): - # process instance after all result rows have been processed. this - # function should be used to issue additional selections in order to - # eagerly load additional properties. - - return (execute, post_execute) - - either tuple value can also be ``None`` in which case no function is called. - + # process instance after all result rows have been processed. + # this function should be used to issue additional selections + # in order to eagerly load additional properties. + + return (new_execute, existing_execute, post_execute) + + Either of the three tuples can be ``None`` in which case no function + is called. """ - + raise NotImplementedError() - - def cascade_iterator(self, type, object, recursive=None, halt_on=None): - """return an iterator of objects which are child objects of the given object, - as attached to the attribute corresponding to this MapperProperty.""" - - return [] - def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None): - """run the given callable across all objects which are child objects of - the given object, as attached to the attribute corresponding to this MapperProperty.""" - - return [] + def cascade_iterator(self, type_, state, visited_instances=None, halt_on=None): + """Iterate through instances related to the given instance for + a particular 'cascade', starting with this MapperProperty. + + See PropertyLoader for the related instance implementation. + """ + + return iter([]) def get_criterion(self, query, key, value): """Return a ``WHERE`` clause suitable for this @@ -277,9 +354,9 @@ class MapperProperty(object): is a value to be matched. This is only picked up by ``PropertyLoaders``. - This is called by a ``Query``'s ``join_by`` method to - formulate a set of key/value pairs into a ``WHERE`` criterion - that spans multiple tables if needed. + This is called by a ``Query``'s ``join_by`` method to formulate a set + of key/value pairs into a ``WHERE`` criterion that spans multiple + tables if needed. """ return None @@ -298,10 +375,10 @@ class MapperProperty(object): def do_init(self): """Perform subclass-specific initialization steps. - - This is a *template* method called by the + + This is a *template* method called by the ``MapperProperty`` object's init() method.""" - + pass def register_dependencies(self, *args, **kwargs): @@ -323,7 +400,7 @@ class MapperProperty(object): level (as opposed to the individual instance level). """ - return self.parent._is_primary_mapper() + return not self.parent.non_primary def merge(self, session, source, dest): """Merge the attribute represented by this ``MapperProperty`` @@ -336,65 +413,95 @@ class MapperProperty(object): this ``MapperProperty`` to the given value, which may be a column value or an instance. 'operator' is an operator from the operators module, or from sql.Comparator. 
- + By default uses the PropComparator attached to this MapperProperty under the attribute name "comparator". """ return operator(self.comparator, value) -class PropComparator(sql.ColumnOperators): - """defines comparison operations for MapperProperty objects""" +class PropComparator(expression.ColumnOperators): + """defines comparison operations for MapperProperty objects. + PropComparator instances should also define an accessor 'property' + which returns the MapperProperty associated with this + PropComparator. + """ + def expression_element(self): return self.clause_element() - + def contains_op(a, b): return a.contains(b) contains_op = staticmethod(contains_op) - + def any_op(a, b, **kwargs): return a.any(b, **kwargs) any_op = staticmethod(any_op) - + def has_op(a, b, **kwargs): return a.has(b, **kwargs) has_op = staticmethod(has_op) - + def __init__(self, prop): - self.prop = prop + self.prop = self.property = prop + + def of_type_op(a, class_): + return a.of_type(class_) + of_type_op = staticmethod(of_type_op) + + def of_type(self, class_): + """Redefine this object in terms of a polymorphic subclass. + + Returns a new PropComparator from which further criterion can be evaluated. + e.g.:: + + query.join(Company.employees.of_type(Engineer)).\\ + filter(Engineer.name=='foo') + + \class_ + a class or mapper indicating that criterion will be against + this specific subclass. + + + """ + + return self.operate(PropComparator.of_type_op, class_) + def contains(self, other): - """return true if this collection contains other""" + """Return true if this collection contains other""" return self.operate(PropComparator.contains_op, other) def any(self, criterion=None, **kwargs): - """return true if this collection contains any member that meets the given criterion. - - criterion - an optional ClauseElement formulated against the member class' table or attributes. - - \**kwargs - key/value pairs corresponding to member class attribute names which will be compared - via equality to the corresponding values. + """Return true if this collection contains any member that meets the given criterion. + + criterion + an optional ClauseElement formulated against the member class' table + or attributes. + + \**kwargs + key/value pairs corresponding to member class attribute names which + will be compared via equality to the corresponding values. """ return self.operate(PropComparator.any_op, criterion, **kwargs) - + def has(self, criterion=None, **kwargs): - """return true if this element references a member which meets the given criterion. - - + """Return true if this element references a member which meets the given criterion. + criterion - an optional ClauseElement formulated against the member class' table or attributes. - + an optional ClauseElement formulated against the member class' table + or attributes. + \**kwargs - key/value pairs corresponding to member class attribute names which will be compared - via equality to the corresponding values. + key/value pairs corresponding to member class attribute names which + will be compared via equality to the corresponding values. """ return self.operate(PropComparator.has_op, criterion, **kwargs) - + + class StrategizedProperty(MapperProperty): """A MapperProperty which uses selectable strategies to affect loading behavior. 
@@ -405,7 +512,8 @@ class StrategizedProperty(MapperProperty): """ def _get_context_strategy(self, context): - return self._get_strategy(context.attributes.get(("loaderstrategy", self), self.strategy.__class__)) + path = context.path + return self._get_strategy(context.attributes.get(("loaderstrategy", path), self.strategy.__class__)) def _get_strategy(self, cls): try: @@ -413,7 +521,6 @@ class StrategizedProperty(MapperProperty): except KeyError: # cache the located strategy per class for faster re-lookup strategy = cls(self) - strategy.is_default = False strategy.init() self._all_strategies[cls] = strategy return strategy @@ -426,140 +533,134 @@ class StrategizedProperty(MapperProperty): def do_init(self): self._all_strategies = {} - self.strategy = self.create_strategy() - self._all_strategies[self.strategy.__class__] = self.strategy - self.strategy.init() + self.strategy = self._get_strategy(self.strategy_class) if self.is_primary(): self.strategy.init_class_attribute() -class LoaderStack(object): - """a stack object used during load operations to track the - current position among a chain of mappers to eager loaders.""" - - def __init__(self): - self.__stack = [] - - def push_property(self, key): - self.__stack.append(key) - - def push_mapper(self, mapper): - self.__stack.append(mapper.base_mapper()) - - def pop(self): - self.__stack.pop() - - def snapshot(self): - """return an 'snapshot' of this stack. - - this is a tuple form of the stack which can be used as a hash key.""" - return tuple(self.__stack) - - def __str__(self): - return "->".join([str(s) for s in self.__stack]) - -class OperationContext(object): - """Serve as a context during a query construction or instance - loading operation. +def build_path(mapper, key, prev=None): + if prev: + return prev + (mapper.base_mapper, key) + else: + return (mapper.base_mapper, key) - Accept ``MapperOption`` objects which may modify its state before proceeding. - """ +def serialize_path(path): + if path is None: + return None - def __init__(self, mapper, options): - self.mapper = mapper - self.options = options - self.attributes = {} - self.recursion_stack = util.Set() - for opt in util.flatten_iterator(options): - self.accept_option(opt) + return [ + (mapper.class_, mapper.entity_name, key) + for mapper, key in [(path[i], path[i+1]) for i in range(0, len(path)-1, 2)] + ] - def accept_option(self, opt): - pass +def deserialize_path(path): + if path is None: + return None -class MapperOption(object): - """Describe a modification to an OperationContext or Query.""" + global class_mapper + if class_mapper is None: + from sqlalchemy.orm import class_mapper - def process_query_context(self, context): - pass + return tuple( + chain(*[(class_mapper(cls, entity), key) for cls, entity, key in path]) + ) - def process_selection_context(self, context): - pass +class MapperOption(object): + """Describe a modification to a Query.""" def process_query(self, query): pass + def process_query_conditionally(self, query): + """same as process_query(), except that this option may not apply + to the given query. 
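For orientation, a small sketch of how these option objects are used from the query side; it assumes a previously configured mapped User class with a standard 'addresses' relation and a bound session, both hypothetical here::

    from sqlalchemy.orm import eagerload

    # eagerload('addresses') produces a StrategizedOption; its string key is
    # resolved into a (base mapper, 'addresses') path tuple via build_path()
    # and stored on the Query as a per-path loader strategy override.
    q = session.query(User).options(eagerload('addresses'))
    users_with_addresses = q.all()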
+ + Used when secondary loaders resend existing options to a new + Query.""" + self.process_query(query) + class ExtensionOption(MapperOption): """a MapperOption that applies a MapperExtension to a query operation.""" - + def __init__(self, ext): self.ext = ext def process_query(self, query): query._extension = query._extension.copy() - query._extension.append(self.ext) + query._extension.insert(self.ext) -class SynonymProperty(MapperProperty): - def __init__(self, name, proxy=False): - self.name = name - self.proxy = proxy - - def setup(self, querycontext, **kwargs): - pass - - def create_row_processor(self, selectcontext, mapper, row): - return (None, None) - - def do_init(self): - if not self.proxy: - return - class SynonymProp(object): - def __set__(s, obj, value): - setattr(obj, self.name, value) - def __delete__(s, obj): - delattr(obj, self.name) - def __get__(s, obj, owner): - if obj is None: - return s - return getattr(obj, self.name) - setattr(self.parent.class_, self.key, SynonymProp()) - - def merge(self, session, source, dest, _recursive): - pass class PropertyOption(MapperOption): """A MapperOption that is applied to a property off the mapper or one of its child mappers, identified by a dot-separated key. """ - def __init__(self, key): + def __init__(self, key, mapper=None): self.key = key + self.mapper = mapper - def process_query_property(self, context, properties): - pass + def process_query(self, query): + self._process(query, True) - def process_selection_property(self, context, properties): - pass + def process_query_conditionally(self, query): + self._process(query, False) - def process_query_context(self, context): - self.process_query_property(context, self._get_properties(context)) + def _process(self, query, raiseerr): + if self._should_log_debug: + self.logger.debug("applying option to Query, property key '%s'" % self.key) + paths = self._get_paths(query, raiseerr) + if paths: + self.process_query_property(query, paths) - def process_selection_context(self, context): - self.process_selection_property(context, self._get_properties(context)) + def process_query_property(self, query, paths): + pass - def _get_properties(self, context): - try: - l = self.__prop - except AttributeError: - l = [] - mapper = context.mapper - for token in self.key.split('.'): - prop = mapper.get_property(token, resolve_synonyms=True) - l.append(prop) + def _get_paths(self, query, raiseerr): + path = None + l = [] + current_path = list(query._current_path) + + if self.mapper: + global class_mapper + if class_mapper is None: + from sqlalchemy.orm import class_mapper + mapper = self.mapper + if isinstance(self.mapper, type): + mapper = class_mapper(mapper) + if mapper is not query.mapper and mapper not in [q.mapper for q in query._entities]: + raise exceptions.ArgumentError("Can't find entity %s in Query. 
Current list: %r" % (str(mapper), [str(m) for m in query._entities])) + else: + mapper = query.mapper + if isinstance(self.key, basestring): + tokens = self.key.split('.') + else: + tokens = util.to_list(self.key) + + for token in tokens: + if isinstance(token, basestring): + prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr) + elif isinstance(token, PropComparator): + prop = token.property + token = prop.key + + else: + raise exceptions.ArgumentError("mapper option expects string key or list of attributes") + + if current_path and token == current_path[1]: + current_path = current_path[2:] + continue + + if prop is None: + return [] + path = build_path(mapper, prop.key, path) + l.append(path) + if getattr(token, '_of_type', None): + mapper = token._of_type + else: mapper = getattr(prop, 'mapper', None) - self.__prop = l return l PropertyOption.logger = logging.class_logger(PropertyOption) - +PropertyOption._should_log_debug = logging.is_debug_enabled(PropertyOption.logger) class AttributeExtension(object): """An abstract class which specifies `append`, `delete`, and `set` @@ -583,22 +684,13 @@ class StrategizedOption(PropertyOption): def is_chained(self): return False - - def process_query_property(self, context, properties): - self.logger.debug("applying option to QueryContext, property key '%s'" % self.key) - if self.is_chained(): - for prop in properties: - context.attributes[("loaderstrategy", prop)] = self.get_strategy_class() - else: - context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class() - def process_selection_property(self, context, properties): - self.logger.debug("applying option to SelectionContext, property key '%s'" % self.key) + def process_query_property(self, query, paths): if self.is_chained(): - for prop in properties: - context.attributes[("loaderstrategy", prop)] = self.get_strategy_class() - else: - context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class() + for path in paths: + query._attributes[("loaderstrategy", path)] = self.get_strategy_class() + else: + query._attributes[("loaderstrategy", paths[-1])] = self.get_strategy_class() def get_strategy_class(self): raise NotImplementedError() @@ -622,21 +714,21 @@ class LoaderStrategy(object): list of selected columns, *eager loading* properties may add ``LEFT OUTER JOIN`` clauses to the statement. - * it processes the SelectionContext at row-processing time. This - may involve setting instance-level lazyloader functions on newly - constructed instances, or may involve recursively appending - child items to a list in response to additionally eager-loaded - objects in the query. + * it processes the ``SelectionContext`` at row-processing time. This + includes straight population of attributes corresponding to rows, + setting instance-level lazyloader callables on newly + constructed instances, and appending child items to scalar/collection + attributes in response to eagerly-loaded relations. """ def __init__(self, parent): self.parent_property = parent - self.is_default = True + self.is_class_level = False def init(self): self.parent = self.parent_property.parent self.key = self.parent_property.key - + def init_class_attribute(self): pass @@ -644,12 +736,11 @@ class LoaderStrategy(object): pass def create_row_processor(self, selectcontext, mapper, row): - """return row processing functions which fulfill the contract specified + """Return row processing functions which fulfill the contract specified by MapperProperty.create_row_processor. 
- - - StrategizedProperty delegates its create_row_processor method - directly to this method. + + StrategizedProperty delegates its create_row_processor method directly + to this method. """ raise NotImplementedError() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 76cc412890..ba0644758f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1,89 +1,95 @@ # orm/mapper.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, logging -from sqlalchemy import sql_util as sqlutil -from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.util import ExtensionCarrier -from sqlalchemy.orm import sync -from sqlalchemy.orm.interfaces import MapperProperty, EXT_PASS, MapperExtension, SynonymProperty -import weakref, warnings, operator +"""Defines the [sqlalchemy.orm.mapper#Mapper] class, the central configurational +unit which associates a class with a database table. + +This is a semi-private module; the main configurational API of the ORM is +available in [sqlalchemy.orm#]. +""" -__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry'] +import weakref +from itertools import chain +from sqlalchemy import sql, util, exceptions, logging +from sqlalchemy.sql import expression, visitors, operators, util as sqlutil +from sqlalchemy.orm import sync, attributes +from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator +from sqlalchemy.orm.util import has_identity, _state_has_identity, _is_mapped_class, has_mapper, \ + _state_mapper, class_mapper, object_mapper, _class_to_mapper,\ + ExtensionCarrier, state_str, instance_str + +__all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry'] -# a dictionary mapping classes to their primary mappers -mapper_registry = weakref.WeakKeyDictionary() +_mapper_registry = weakref.WeakKeyDictionary() +_new_mappers = False +_already_compiling = False # a list of MapperExtensions that will be installed in all mappers by default global_extensions = [] -# a constant returned by get_attr_by_column to indicate +# a constant returned by _get_attr_by_column to indicate # this mapper is not handling an attribute for a particular # column -NO_ATTRIBUTE = object() +NO_ATTRIBUTE = util.symbol('NO_ATTRIBUTE') # lock used to synchronize the "mapper compile" step -_COMPILE_MUTEX = util.threading.Lock() +_COMPILE_MUTEX = util.threading.RLock() -# initialize these two lazily -attribute_manager = None +# initialize these lazily ColumnProperty = None +SynonymProperty = None +ComparableProperty = None +_expire_state = None + class Mapper(object): """Define the correlation of class attributes to database table columns. Instances of this class should be constructed via the - ``sqlalchemy.orm.mapper()`` function. + [sqlalchemy.orm#mapper()] function. 
""" def __init__(self, - class_, - local_table, - properties = None, - primary_key = None, - non_primary = False, - inherits = None, - inherit_condition = None, - extension = None, - order_by = False, - allow_column_override = False, - entity_name = None, - always_refresh = False, - version_id_col = None, - polymorphic_on=None, - _polymorphic_map=None, - polymorphic_identity=None, - polymorphic_fetch=None, - concrete=False, - select_table=None, - allow_null_pks=False, - batch=True, - column_prefix=None): + class_, + local_table, + properties = None, + primary_key = None, + non_primary = False, + inherits = None, + inherit_condition = None, + inherit_foreign_keys = None, + extension = None, + order_by = False, + allow_column_override = False, + entity_name = None, + always_refresh = False, + version_id_col = None, + polymorphic_on=None, + _polymorphic_map=None, + polymorphic_identity=None, + polymorphic_fetch=None, + concrete=False, + select_table=None, + with_polymorphic=None, + allow_null_pks=False, + batch=True, + column_prefix=None, + include_properties=None, + exclude_properties=None, + eager_defaults=False): """Construct a new mapper. - Mappers are normally constructed via the [sqlalchemy.orm#mapper()] + Mappers are normally constructed via the [sqlalchemy.orm#mapper()] function. See for details. + """ - if not issubclass(class_, object): - raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) - - for table in (local_table, select_table): - if table is not None and isinstance(table, sql._SelectBaseMixin): - # some db's, noteably postgres, dont want to select from a select - # without an alias. also if we make our own alias internally, then - # the configured properties on the mapper are not matched against the alias - # we make, theres workarounds but it starts to get really crazy (its crazy enough - # the SQL that gets generated) so just require an alias - raise exceptions.ArgumentError("Mapping against a Select object requires that it has a name. Use an alias to give it a name, i.e. s = select(...).alias('myselect')") - self.class_ = class_ self.entity_name = entity_name - self.class_key = ClassKey(class_, entity_name) self.primary_key_argument = primary_key self.non_primary = non_primary self.order_by = order_by @@ -92,21 +98,51 @@ class Mapper(object): self.concrete = concrete self.single = False self.inherits = inherits - self.select_table = select_table self.local_table = local_table self.inherit_condition = inherit_condition + self.inherit_foreign_keys = inherit_foreign_keys self.extension = extension - self.properties = properties or {} + self._init_properties = properties or {} self.allow_column_override = allow_column_override self.allow_null_pks = allow_null_pks self.delete_orphans = [] self.batch = batch + self.eager_defaults = eager_defaults self.column_prefix = column_prefix - # a Column which is used during a select operation to retrieve the - # "polymorphic identity" of the row, which indicates which Mapper should be used - # to construct a new object instance from that row. 
self.polymorphic_on = polymorphic_on self._eager_loaders = util.Set() + self._dependency_processors = [] + self._clause_adapter = None + self._requires_row_aliasing = False + self.__inherits_equated_pairs = None + + if not issubclass(class_, object): + raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) + + self.select_table = select_table + if select_table: + if with_polymorphic: + raise exceptions.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)") + self.with_polymorphic = ('*', select_table) + else: + if with_polymorphic == '*': + self.with_polymorphic = ('*', None) + elif isinstance(with_polymorphic, (tuple, list)): + if isinstance(with_polymorphic[0], (basestring, tuple, list)): + self.with_polymorphic = with_polymorphic + else: + self.with_polymorphic = (with_polymorphic, None) + elif with_polymorphic is not None: + raise exceptions.ArgumentError("Invalid setting for with_polymorphic") + else: + self.with_polymorphic = None + + if isinstance(self.local_table, expression._SelectBaseMixin): + util.warn("mapper %s creating an alias for the given selectable - use Class attributes for queries." % self) + self.local_table = self.local_table.alias() + + if self.with_polymorphic and isinstance(self.with_polymorphic[1], expression._SelectBaseMixin): + self.with_polymorphic[1] = self.with_polymorphic[1].alias() # our 'polymorphic identity', a string name that when located in a result set row # indicates this Mapper should be used to construct the object instance for that row. @@ -115,10 +151,10 @@ class Mapper(object): if polymorphic_fetch not in (None, 'union', 'select', 'deferred'): raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch) if polymorphic_fetch is None: - self.polymorphic_fetch = (self.select_table is None) and 'select' or 'union' + self.polymorphic_fetch = (self.with_polymorphic is None) and 'select' or 'union' else: self.polymorphic_fetch = polymorphic_fetch - + # a dictionary of 'polymorphic identity' names, associating those names with # Mappers that will be used to construct object instances upon a select operation. if _polymorphic_map is None: @@ -126,72 +162,54 @@ class Mapper(object): else: self.polymorphic_map = _polymorphic_map - class LOrderedProp(util.OrderedProperties): - """this extends OrderedProperties to trigger a compile() before the - members of the object are accessed.""" - def _get_data(s): - self.compile() - return s.__dict__['_data'] - _data = property(_get_data) - - self.columns = LOrderedProp() - self.c = self.columns + self.columns = self.c = util.OrderedProperties() - # each time the options() method is called, the resulting Mapper is - # stored in this dictionary based on the given options for fast re-access - self._options = {} + self.include_properties = include_properties + self.exclude_properties = exclude_properties # a set of all mappers which inherit from this one. self._inheriting_mappers = util.Set() - # a second mapper that is used for selecting, if the "select_table" argument - # was sent to this mapper. - self.__surrogate_mapper = None - - # whether or not our compile() method has been called already. - self.__is_compiled = False - - # if this mapper is to be a primary mapper (i.e. the non_primary flag is not set), - # associate this Mapper with the given class_ and entity name. subsequent - # calls to class_mapper() for the class_/entity name combination will return this - # mapper. 
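# Editor's sketch (not part of the patch) of the with_polymorphic argument
# resolved above: '*' asks queries against the base class to outerjoin to every
# known subclass table, replacing the older
# select_table=people.outerjoin(engineers) idiom. Tables and classes here are
# hypothetical.
from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey
from sqlalchemy.orm import mapper

metadata = MetaData()
people = Table('people', metadata,
    Column('person_id', Integer, primary_key=True),
    Column('type', String(30)),
    Column('name', String(50)))
engineers = Table('engineers', metadata,
    Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
    Column('primary_language', String(50)))

class Person(object): pass
class Engineer(Person): pass

mapper(Person, people, polymorphic_on=people.c.type,
       polymorphic_identity='person', with_polymorphic='*')
mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')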
- self._compile_class() + self.__props_init = False + self.__should_log_info = logging.is_info_enabled(self.logger) self.__should_log_debug = logging.is_debug_enabled(self.logger) - self.__log("constructed") - # uncomment to compile at construction time (the old way) - # this will break mapper setups that arent declared in the order - # of dependency - #self.compile() + self.__compile_class() + self.__compile_inheritance() + self.__compile_extensions() + self.__compile_properties() + self.__compile_pks() + global _new_mappers + _new_mappers = True + self.__log("constructed") def __log(self, msg): - self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.name or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "") + ") " + msg) + if self.__should_log_info: + self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg) def __log_debug(self, msg): - self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.name or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "") + ") " + msg) + if self.__should_log_debug: + self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg) def _is_orphan(self, obj): - optimistic = has_identity(obj) - for (key,klass) in self.delete_orphans: - if getattr(klass, key).hasparent(obj, optimistic=optimistic): - return False - else: - if len(self.delete_orphans): - if not has_identity(obj): - raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" % - ( - obj, - ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in self.delete_orphans]) - )) - else: - return True - else: - return False + o = False + for mapper in self.iterate_to_root(): + for (key,klass) in mapper.delete_orphans: + if attributes.has_parent(klass, obj, key, optimistic=has_identity(obj)): + return False + o = o or bool(mapper.delete_orphans) + return o def get_property(self, key, resolve_synonyms=False, raiseerr=True): - """return MapperProperty with the given key.""" + """return a MapperProperty associated with the given key.""" + self.compile() + return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr) + + def _get_property(self, key, resolve_synonyms=False, raiseerr=True): + """private in-compilation version of get_property().""" + prop = self.__props.get(key, None) if resolve_synonyms: while isinstance(prop, SynonymProperty): @@ -199,125 +217,205 @@ class Mapper(object): if prop is None and raiseerr: raise exceptions.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key)) return prop - + def iterate_properties(self): self.compile() return self.__props.itervalues() iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.") + + def __adjust_wp_selectable(self, spec=None, selectable=False): + """given a with_polymorphic() argument, resolve it 
against this mapper's with_polymorphic setting""" + + isdefault = False + if self.with_polymorphic: + isdefault = not spec and selectable is False + + if not spec: + spec = self.with_polymorphic[0] + if selectable is False: + selectable = self.with_polymorphic[1] + + return spec, selectable, isdefault + + def __mappers_from_spec(self, spec, selectable): + """given a with_polymorphic() argument, return the set of mappers it represents. + + Trims the list of mappers to just those represented within the given selectable, if present. + This helps some more legacy-ish mappings. + + """ + if spec == '*': + mappers = list(self.polymorphic_iterator()) + elif spec: + mappers = [_class_to_mapper(m) for m in util.to_list(spec)] + else: + mappers = [] + + if selectable: + tables = util.Set(sqlutil.find_tables(selectable)) + mappers = [m for m in mappers if m.local_table in tables] + + return mappers + __mappers_from_spec = util.conditional_cache_decorator(__mappers_from_spec) + + def __selectable_from_mappers(self, mappers): + """given a list of mappers (assumed to be within this mapper's inheritance hierarchy), + construct an outerjoin amongst those mapper's mapped tables. + + """ + from_obj = self.mapped_table + for m in mappers: + if m is self: + continue + if m.concrete: + raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.") + elif not m.single: + from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition) + + return from_obj + __selectable_from_mappers = util.conditional_cache_decorator(__selectable_from_mappers) + def _with_polymorphic_mappers(self, spec=None, selectable=False): + spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) + return self.__mappers_from_spec(spec, selectable, cache=isdefault) + + def _with_polymorphic_selectable(self, spec=None, selectable=False): + spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) + if selectable: + return selectable + else: + return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable, cache=isdefault), cache=isdefault) + + def _with_polymorphic_args(self, spec=None, selectable=False): + spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) + mappers = self.__mappers_from_spec(spec, selectable, cache=isdefault) + if selectable: + return mappers, selectable + else: + return mappers, self.__selectable_from_mappers(mappers, cache=isdefault) + + def _iterate_polymorphic_properties(self, spec=None, selectable=False): + return iter(util.OrderedSet( + chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers(spec, selectable)]) + )) + + def properties(self): + raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.") + properties = property(properties) + + def compiled(self): + """return True if this mapper is compiled""" + return self.__props_init + compiled = property(compiled) + def dispose(self): - attribute_manager.reset_class_managed(self.class_) - if hasattr(self.class_, 'c'): + # disaable any attribute-based compilation + self.__props_init = True + try: del self.class_.c - if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'): - if self.class_.__init__._oldinit is not None: - self.class_.__init__ = self.class_.__init__._oldinit - else: - delattr(self.class_, '__init__') - - def compile(self): - """Compile this mapper into its final 
internal format. + except AttributeError: + pass + if not self.non_primary and self.entity_name in self._class_state.mappers: + del self._class_state.mappers[self.entity_name] + if not self._class_state.mappers: + attributes.unregister_class(self.class_) - This is the *external* version of the method which is not - reentrant. + def compile(self): + """Compile this mapper and all other non-compiled mappers. + + This method checks the local compiled status as well as for + any new mappers that have been defined, and is safe to call + repeatedly. """ - - if self.__is_compiled: + + global _new_mappers + if self.__props_init and not _new_mappers: return self + _COMPILE_MUTEX.acquire() + global _already_compiling + if _already_compiling: + self.__initialize_properties() + return + _already_compiling = True try: + # double-check inside mutex - if self.__is_compiled: + if self.__props_init and not _new_mappers: return self - self._compile_all() - # if we're not primary, compile us - if self.non_primary: - self._do_compile() - self._initialize_properties() + # initialize properties on all mappers + for mapper in list(_mapper_registry): + if not mapper.__props_init: + mapper.__initialize_properties() + _new_mappers = False return self finally: + _already_compiling = False _COMPILE_MUTEX.release() - def _compile_all(self): - # compile all primary mappers - for mapper in mapper_registry.values(): - if not mapper.__is_compiled: - mapper._do_compile() - - # initialize properties on all mappers - for mapper in mapper_registry.values(): - if not mapper.__props_init: - mapper._initialize_properties() - - def _check_compile(self): - if self.non_primary: - self._do_compile() - self._initialize_properties() - return self + def __initialize_properties(self): + """Call the ``init()`` method on all ``MapperProperties`` + attached to this mapper. + + This is a deferred configuration step which is intended + to execute once all mappers have been constructed. + """ - def _do_compile(self): - """Compile this mapper into its final internal format. + self.__log("__initialize_properties() started") + l = [(key, prop) for key, prop in self.__props.iteritems()] + for key, prop in l: + self.__log("initialize prop " + key) + if getattr(prop, 'key', None) is None: + prop.init(key, self) + self.__log("__initialize_properties() complete") + self.__props_init = True - This is the *internal* version of the method which is assumed - to be called within compile() and is reentrant. - """ - if self.__is_compiled: - return self - self.__log("_do_compile() started") - self.__is_compiled = True - self.__props_init = False - self._compile_extensions() - self._compile_inheritance() - self._compile_tables() - self._compile_properties() - self._compile_selectable() - self.__log("_do_compile() complete") - return self - - def _compile_extensions(self): + def __compile_extensions(self): """Go through the global_extensions list as well as the list of ``MapperExtensions`` specified for this ``Mapper`` and creates a linked list of those extensions. 
""" - extlist = util.Set() - for ext_class in global_extensions: - if isinstance(ext_class, MapperExtension): - extlist.add(ext_class) - else: - extlist.add(ext_class()) + extlist = util.OrderedSet() extension = self.extension - if extension is not None: + if extension: for ext_obj in util.to_list(extension): + # local MapperExtensions have already instrumented the class extlist.add(ext_obj) + if self.inherits: + for ext in self.inherits.extension: + if ext not in extlist: + extlist.add(ext) + ext.instrument_class(self, self.class_) + else: + for ext in global_extensions: + if isinstance(ext, type): + ext = ext() + if ext not in extlist: + extlist.add(ext) + ext.instrument_class(self, self.class_) + self.extension = ExtensionCarrier() for ext in extlist: self.extension.append(ext) - def _compile_inheritance(self): - """Determine if this Mapper inherits from another mapper, and - if so calculates the mapped_table for this Mapper taking the - inherited mapper into account. - - For joined table inheritance, creates a ``SyncRule`` that will - synchronize column values between the joined tables. also - initializes polymorphic variables used in polymorphic loads. - """ + def __compile_inheritance(self): + """Configure settings related to inherting and/or inherited mappers being present.""" - if self.inherits is not None: + if self.inherits: if isinstance(self.inherits, type): - self.inherits = class_mapper(self.inherits, compile=False)._do_compile() + self.inherits = class_mapper(self.inherits, compile=False) else: - self.inherits = self.inherits._do_compile() + self.inherits = self.inherits if not issubclass(self.class_, self.inherits.class_): raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__)) - if self._is_primary_mapper() != self.inherits._is_primary_mapper(): - np = self._is_primary_mapper() and "primary" or "non-primary" + if self.non_primary != self.inherits.non_primary: + np = not self.non_primary and "primary" or "non-primary" raise exceptions.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np)) # inherit_condition is optional. if self.local_table is None: @@ -325,171 +423,131 @@ class Mapper(object): self.single = True if not self.local_table is self.inherits.local_table: if self.concrete: - self._synchronizer= None self.mapped_table = self.local_table + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on: + mapper._requires_row_aliasing = True else: if self.inherit_condition is None: # figure out inherit condition from our table to the immediate table # of the inherited mapper, not its full table which could pull in other # stuff we dont want (allows test/inheritance.InheritTest4 to pass) - self.inherit_condition = sql.join(self.inherits.local_table, self.local_table).onclause + self.inherit_condition = sqlutil.join_condition(self.inherits.local_table, self.local_table) self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition) - # generate sync rules. 
similarly to creating the on clause, specify a - # stricter set of tables to create "sync rules" by,based on the immediate - # inherited table, rather than all inherited tables - self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - self._synchronizer.compile(self.mapped_table.onclause) + + fks = util.to_set(self.inherit_foreign_keys) + self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks) else: - self._synchronizer = None self.mapped_table = self.local_table if self.polymorphic_identity is not None: - self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self) + self.inherits.polymorphic_map[self.polymorphic_identity] = self if self.polymorphic_on is None: - if self.inherits.polymorphic_on is not None: - self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False) + for mapper in self.iterate_to_root(): + # try to set up polymorphic on using correesponding_column(); else leave + # as None + if mapper.polymorphic_on: + self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on) + break else: + # TODO: this exception not covered raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) - if self.polymorphic_identity is not None and not self.concrete: + if self.polymorphic_identity and not self.concrete: self._identity_class = self.inherits._identity_class else: self._identity_class = self.class_ - + + if self.version_id_col is None: + self.version_id_col = self.inherits.version_id_col + + for mapper in self.iterate_to_root(): + util.reset_cached(mapper, '_equivalent_columns') + if self.order_by is False: self.order_by = self.inherits.order_by self.polymorphic_map = self.inherits.polymorphic_map self.batch = self.inherits.batch self.inherits._inheriting_mappers.add(self) + self.base_mapper = self.inherits.base_mapper + self._all_tables = self.inherits._all_tables else: - self._synchronizer = None + self._all_tables = util.Set() + self.base_mapper = self self.mapped_table = self.local_table - if self.polymorphic_identity is not None: + if self.polymorphic_identity: if self.polymorphic_on is None: raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) - self._add_polymorphic_mapping(self.polymorphic_identity, self) + self.polymorphic_map[self.polymorphic_identity] = self self._identity_class = self.class_ - + if self.mapped_table is None: raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? 
It no longer has a return value.)" % str(self)) - # convert polymorphic class associations to mappers - for key in self.polymorphic_map.keys(): - if isinstance(self.polymorphic_map[key], type): - self.polymorphic_map[key] = class_mapper(self.polymorphic_map[key]) - - def _add_polymorphic_mapping(self, key, class_or_mapper, entity_name=None): - """Add a Mapper to our *polymorphic map*.""" + def __compile_pks(self): - if isinstance(class_or_mapper, type): - class_or_mapper = class_mapper(class_or_mapper, entity_name=entity_name) - self.polymorphic_map[key] = class_or_mapper + self.tables = sqlutil.find_tables(self.mapped_table) - def _compile_tables(self): - """After the inheritance relationships have been reconciled, - set up some more table-based instance variables and determine - the *primary key* columns for all tables represented by this - ``Mapper``. - """ - - # summary of the various Selectable units: - # mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table) - # local_table - the Selectable that was passed to this Mapper's constructor, if any - # select_table - the Selectable that will be used during queries. if this is specified - # as a constructor keyword argument, it takes precendence over mapped_table, otherwise its mapped_table - # this is either select_table if it was given explicitly, or in the case of a mapper that inherits - # its local_table - # tables - a collection of underlying Table objects pulled from mapped_table + if not self.tables: + raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) - if self.select_table is None: - self.select_table = self.mapped_table + self._pks_by_table = {} + self._cols_by_table = {} - # locate all tables contained within the "table" passed in, which - # may be a join or other construct - self.tables = sqlutil.TableFinder(self.mapped_table) + all_cols = util.Set(chain(*[col.proxy_set for col in self._columntoproperty])) + pk_cols = util.Set([c for c in all_cols if c.primary_key]) - # determine primary key columns - self.pks_by_table = {} + # identify primary key columns which are also mapped by this mapper. 
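# Editor's sketch (not part of the patch) of the primary key assembly performed
# by __compile_pks() above: when mapping against a join, the mapper collects the
# primary key columns of each underlying Table, and an explicit primary_key
# argument can override that determination. Names here are hypothetical.
from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey, join
from sqlalchemy.orm import mapper

metadata = MetaData()
users = Table('users', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', String(50)))
addresses = Table('addresses', metadata,
    Column('id', Integer, primary_key=True),
    Column('user_id', Integer, ForeignKey('users.id')),
    Column('email', String(100)))

class UserAddress(object): pass

# without the primary_key argument, reduce_columns() would trim the join's two
# primary key columns down to a minimal set; here the pair is given explicitly.
mapper(UserAddress, join(users, addresses),
       primary_key=[users.c.id, addresses.c.id])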
+ for t in util.Set(self.tables + [self.mapped_table]): + self._all_tables.add(t) + if t.primary_key and pk_cols.issuperset(t.primary_key): + # ordering is important since it determines the ordering of mapper.primary_key (and therefore query.get()) + self._pks_by_table[t] = util.OrderedSet(t.primary_key).intersection(pk_cols) + self._cols_by_table[t] = util.OrderedSet(t.c).intersection(all_cols) - # go through all of our represented tables - # and assemble primary key columns - for t in self.tables + [self.mapped_table]: - try: - l = self.pks_by_table[t] - except KeyError: - l = self.pks_by_table.setdefault(t, util.OrderedSet()) - for k in t.primary_key: - l.add(k) - - if self.primary_key_argument is not None: + # if explicit PK argument sent, add those columns to the primary key mappings + if self.primary_key_argument: for k in self.primary_key_argument: - self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k) - - if len(self.pks_by_table[self.mapped_table]) == 0: - raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) + if k.table not in self._pks_by_table: + self._pks_by_table[k.table] = util.OrderedSet() + self._pks_by_table[k.table].add(k) + + if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0: + raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) - if self.inherits is not None and not self.concrete and not self.primary_key_argument: + if self.inherits and not self.concrete and not self.primary_key_argument: + # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit) self.primary_key = self.inherits.primary_key - self._get_clause = self.inherits._get_clause else: - # create the "primary_key" for this mapper. this will flatten "equivalent" primary key columns - # into one column, where "equivalent" means that one column references the other via foreign key, or - # multiple columns that all reference a common parent column. it will also resolve the column - # against the "mapped_table" of this mapper. - equivalent_columns = self._get_equivalent_columns() - - primary_key = sql.ColumnSet() - - for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): - c = self.mapped_table.corresponding_column(col, raiseerr=False) - if c is None: - for cc in equivalent_columns[col]: - c = self.mapped_table.corresponding_column(cc, raiseerr=False) - if c is not None: - break - else: - raise exceptions.ArgumentError("Cant resolve column " + str(col)) - - # this step attempts to resolve the column to an equivalent which is not - # a foreign key elsewhere. 
this helps with joined table inheritance - # so that PKs are expressed in terms of the base table which is always - # present in the initial select - # TODO: this is a little hacky right now, the "tried" list is to prevent - # endless loops between cyclical FKs, try to make this cleaner/work better/etc., - # perhaps via topological sort (pick the leftmost item) - tried = util.Set() - while True: - if not len(c.foreign_keys) or c in tried: - break - for cc in c.foreign_keys: - cc = cc.column - c2 = self.mapped_table.corresponding_column(cc, raiseerr=False) - if c2 is not None: - c = c2 - tried.add(c) - break - else: - break - primary_key.add(c) - + # determine primary key from argument or mapped_table pks - reduce to the minimal set of columns + if self.primary_key_argument: + primary_key = sqlutil.reduce_columns([self.mapped_table.corresponding_column(c) for c in self.primary_key_argument]) + else: + primary_key = sqlutil.reduce_columns(self._pks_by_table[self.mapped_table]) + if len(primary_key) == 0: - raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) + raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) self.primary_key = primary_key self.__log("Identified primary key columns: " + str(primary_key)) - - _get_clause = sql.and_() - for primary_key in self.primary_key: - _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True)) - self._get_clause = _get_clause - def _get_equivalent_columns(self): + def _get_clause(self): + """create a "get clause" based on the primary key. this is used + by query.get() and many-to-one lazyloads to load this item + by primary key. + + """ + params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key] + return sql.and_(*[k==v for (k, v) in params]), dict(params) + _get_clause = property(util.cache_decorator(_get_clause)) + + def _equivalent_columns(self): """Create a map of all *equivalent* columns, based on the determination of column pairs that are equated to one another either by an established foreign key relationship or by a joined-table inheritance join. - This is used to determine the minimal set of primary key - columns for the mapper, as well as when relating + This is used to determine the minimal set of primary key + columns for the mapper, as well as when relating columns to those of a polymorphic selectable (i.e. a UNION of several mapped tables), as that selectable usually only contains one column in its columns clause out of a group of several which @@ -499,20 +557,17 @@ class Mapper(object): to lists of equivalent columns, i.e. { - tablea.col1: + tablea.col1: set([tableb.col1, tablec.col1]), tablea.col2: set([tabled.col2]) } - - this method is called repeatedly during the compilation process as - the resulting dictionary contains more equivalents as more inheriting - mappers are compiled. the repetition process may be open to some optimization. 
+ """ result = {} def visit_binary(binary): - if binary.operator == operator.eq: + if binary.operator == operators.eq: if binary.left in result: result[binary.left].add(binary.right) else: @@ -521,30 +576,67 @@ class Mapper(object): result[binary.right].add(binary.left) else: result[binary.right] = util.Set([binary.left]) - vis = mapperutil.BinaryVisitor(visit_binary) + for mapper in self.base_mapper.polymorphic_iterator(): + if mapper.inherit_condition: + visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary) - for mapper in self.base_mapper().polymorphic_iterator(): - if mapper.inherit_condition is not None: - vis.traverse(mapper.inherit_condition) + # TODO: matching of cols to foreign keys might better be generalized + # into general column translation (i.e. corresponding_column) - for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): - if not len(col.foreign_keys): - result.setdefault(col, util.Set()).add(col) - else: - for fk in col.foreign_keys: - result.setdefault(fk.column, util.Set()).add(col) + # recursively descend into the foreign key collection of the given column + # and assemble each FK-related col as an "equivalent" for the given column + def equivs(col, recursive, equiv): + if col in recursive: + return + recursive.add(col) + for fk in col.foreign_keys: + if fk.column not in result: + result[fk.column] = util.Set() + result[fk.column].add(equiv) + equivs(fk.column, recursive, col) + + for column in (self.primary_key_argument or self._pks_by_table[self.mapped_table]): + for col in column.proxy_set: + if not col.foreign_keys: + if col not in result: + result[col] = util.Set() + result[col].add(col) + else: + equivs(col, util.Set(), col) return result - - def _compile_properties(self): - """Inspect the properties dictionary sent to the Mapper's - constructor as well as the mapped_table, and create - ``MapperProperty`` objects corresponding to each mapped column - and relation. - - Also grab ``MapperProperties`` from the inherited mapper, if - any, and create copies of them to attach to this Mapper. - """ + _equivalent_columns = property(util.cache_decorator(_equivalent_columns)) + + class _CompileOnAttr(PropComparator): + """A placeholder descriptor which triggers compilation on access.""" + + def __init__(self, class_, key): + self.class_ = class_ + self.key = key + self.existing_prop = getattr(class_, key, None) + + def __getattribute__(self, key): + cls = object.__getattribute__(self, 'class_') + clskey = object.__getattribute__(self, 'key') + + if key.startswith('__'): + return object.__getattribute__(self, key) + + class_mapper(cls) + + if cls.__dict__.get(clskey) is self: + # FIXME: there should not be any scenarios where + # a mapper compile leaves this CompileOnAttr in + # place. 
+ util.warn( + ("Attribute '%s' on class '%s' was not replaced during " + "mapper compilation operation") % (clskey, cls.__name__)) + # clean us up explicitly + delattr(cls, clskey) + + return getattr(getattr(cls, clskey), key) + + def __compile_properties(self): # object attribute names mapped to MapperProperty objects self.__props = util.OrderedDict() @@ -552,93 +644,135 @@ class Mapper(object): # table columns mapped to lists of MapperProperty objects # using a list allows a single column to be defined as # populating multiple object attributes - self.columntoproperty = mapperutil.TranslatingDict(self.mapped_table) + self._columntoproperty = {} # load custom properties - if self.properties is not None: - for key, prop in self.properties.iteritems(): + if self._init_properties: + for key, prop in self._init_properties.iteritems(): self._compile_property(key, prop, False) - if self.inherits is not None: + # pull properties from the inherited mapper if any. + if self.inherits: for key, prop in self.inherits.__props.iteritems(): - if not self.__props.has_key(key): + if key not in self.__props: self._adapt_inherited_property(key, prop) - # load properties from the main table object, - # not overriding those set up in the 'properties' argument + # create properties for each column in the mapped table, + # for those columns which don't already map to a property for column in self.mapped_table.columns: - if self.columntoproperty.has_key(column): + if column in self._columntoproperty: + continue + + if (self.include_properties is not None and + column.key not in self.include_properties): + self.__log("not including property %s" % (column.key)) + continue + + if (self.exclude_properties is not None and + column.key in self.exclude_properties): + self.__log("excluding property %s" % (column.key)) continue - if not self.columns.has_key(column.key): - self.columns[column.key] = self.select_table.corresponding_column(column, keys_ok=True, raiseerr=True) column_key = (self.column_prefix or '') + column.key - prop = self.__props.get(column.key, None) - if prop is None: - prop = ColumnProperty(column) - self.__props[column_key] = prop - prop.set_parent(self) - self.__log("adding ColumnProperty %s" % (column_key)) - elif isinstance(prop, ColumnProperty): + + self._compile_property(column_key, column, init=False, setparent=True) + + # do a special check for the "discriminiator" column, as it may only be present + # in the 'with_polymorphic' selectable but we need it for the base mapper + if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty: + col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on + self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True) + + def _adapt_inherited_property(self, key, prop): + if not self.concrete: + self._compile_property(key, prop, init=False, setparent=False) + # TODO: concrete properties dont adapt at all right now....will require copies of relations() etc. 
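# Editor's sketch (not part of the patch) of the include_properties /
# exclude_properties arguments consumed by __compile_properties() above.
# Table and class names here are hypothetical.
from sqlalchemy import MetaData, Table, Column, Integer, String
from sqlalchemy.orm import mapper

metadata = MetaData()
users = Table('users', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', String(50)),
    Column('password', String(50)))

class User(object): pass

# only 'id' and 'name' become mapped attributes; 'password' is skipped, just as
# an exclude_properties=['password'] mapping would skip it.
mapper(User, users, include_properties=['id', 'name'])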
+ + def _compile_property(self, key, prop, init=True, setparent=True): + self.__log("_compile_property(%s, %s)" % (key, prop.__class__.__name__)) + + if not isinstance(prop, MapperProperty): + # we were passed a Column or a list of Columns; generate a ColumnProperty + columns = util.to_list(prop) + column = columns[0] + if not expression.is_column(column): + raise exceptions.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop)) + + prop = self.__props.get(key, None) + + if isinstance(prop, ColumnProperty): + # TODO: the "property already exists" case is still not well defined here. + # assuming single-column, etc. + if prop.parent is not self: + # existing ColumnProperty from an inheriting mapper. + # make a copy and append our column to it prop = prop.copy() - prop.set_parent(self) - self.__props[column_key] = prop - if column in self.primary_key and prop.columns[-1] in self.primary_key: - warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'. Use explicit properties to give each column its own mapped attribute name." % (str(self), str(column), str(prop.columns[-1]), column_key))) prop.columns.append(column) - self.__log("appending to existing ColumnProperty %s" % (column_key)) + self.__log("appending to existing ColumnProperty %s" % (key)) + elif prop is None: + mapped_column = [] + for c in columns: + mc = self.mapped_table.corresponding_column(c) + if not mc: + raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c)) + mapped_column.append(mc) + prop = ColumnProperty(*mapped_column) else: if not self.allow_column_override: raise exceptions.ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop))) else: - continue - - # its a ColumnProperty - match the ultimate table columns - # back to the property - self.columntoproperty.setdefault(column, []).append(prop) - + return - def _initialize_properties(self): - """Call the ``init()`` method on all ``MapperProperties`` - attached to this mapper. + if isinstance(prop, ColumnProperty): + col = self.mapped_table.corresponding_column(prop.columns[0]) + # col might not be present! the selectable given to the mapper need not include "deferred" + # columns (included in zblog tests) + if col is None: + col = prop.columns[0] + else: + # if column is coming in after _cols_by_table was initialized, ensure the col is in the + # right set + if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]: + self._cols_by_table[col.table].add(col) - This happens after all mappers have completed compiling - everything else up until this point, so that all dependencies - are fully available. 
- """ + self.columns[key] = col + for col in prop.columns: + for col in col.proxy_set: + self._columntoproperty[col] = prop + + + elif isinstance(prop, SynonymProperty) and setparent: + if prop.descriptor is None: + prop.descriptor = getattr(self.class_, key, None) + if isinstance(prop.descriptor, Mapper._CompileOnAttr): + prop.descriptor = object.__getattribute__(prop.descriptor, 'existing_prop') + if prop.map_column: + if not key in self.mapped_table.c: + raise exceptions.ArgumentError("Can't compile synonym '%s': no column on table '%s' named '%s'" % (prop.name, self.mapped_table.description, key)) + self._compile_property(prop.name, ColumnProperty(self.mapped_table.c[key]), init=init, setparent=setparent) + elif isinstance(prop, ComparableProperty) and setparent: + # refactor me + if prop.descriptor is None: + prop.descriptor = getattr(self.class_, key, None) + if isinstance(prop.descriptor, Mapper._CompileOnAttr): + prop.descriptor = object.__getattribute__(prop.descriptor, + 'existing_prop') + self.__props[key] = prop - self.__log("_initialize_properties() started") - l = [(key, prop) for key, prop in self.__props.iteritems()] - for key, prop in l: - if getattr(prop, 'key', None) is None: - prop.init(key, self) - self.__log("_initialize_properties() complete") - self.__props_init = True + if setparent: + prop.set_parent(self) - def _compile_selectable(self): - """If the 'select_table' keyword argument was specified, set - up a second *surrogate mapper* that will be used for select - operations. + if not self.non_primary: + setattr(self.class_, key, Mapper._CompileOnAttr(self.class_, key)) - The columns of `select_table` should encompass all the columns - of the `mapped_table` either directly or through proxying - relationships. Currently, non-column properties are **not** - copied. This implies that a polymorphic mapper can't do any - eager loading right now. - """ + if init: + prop.init(key, self) + + for mapper in self._inheriting_mappers: + mapper._adapt_inherited_property(key, prop) - if self.select_table is not self.mapped_table: - props = {} - if self.properties is not None: - for key, prop in self.properties.iteritems(): - if sql.is_column(prop): - props[key] = self.select_table.corresponding_column(prop) - elif (isinstance(prop, list) and sql.is_column(prop[0])): - props[key] = [self.select_table.corresponding_column(c) for c in prop] - self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument) - - def _compile_class(self): + def __compile_class(self): """If this mapper is to be a primary mapper (i.e. the non_primary flag is not set), associate this Mapper with the given class_ and entity name. @@ -650,72 +784,52 @@ class Mapper(object): """ if self.non_primary: + if not hasattr(self.class_, '_class_state'): + raise exceptions.InvalidRequestError("Class %s has no primary mapper configured. Configure a primary mapper first before setting up a non primary Mapper.") + self._class_state = self.class_._class_state + _mapper_registry[self] = True return - if not self.non_primary and (mapper_registry.has_key(self.class_key)): - raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. 
Use non_primary=True to create a non primary Mapper, or to create a new primary mapper, remove this mapper first via sqlalchemy.orm.clear_mapper(mapper), or preferably sqlalchemy.orm.clear_mappers() to clear all mappers." % (self.class_, self.entity_name)) + if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers): + raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper. clear_mappers() will remove *all* current mappers from all classes." % (self.class_, self.entity_name)) - attribute_manager.reset_class_managed(self.class_) - - oldinit = self.class_.__init__ - def init(instance, *args, **kwargs): + def extra_init(class_, oldinit, instance, args, kwargs): self.compile() - self.extension.init_instance(self, self.class_, instance, args, kwargs) + if 'init_instance' in self.extension.methods: + self.extension.init_instance(self, class_, oldinit, instance, args, kwargs) - if oldinit is not None: - try: - oldinit(instance, *args, **kwargs) - except: - # call init_failed but suppress exceptions into warnings so that original __init__ - # exception is raised - util.warn_exception(self.extension.init_failed, self, self.class_, instance, args, kwargs) - raise - - # override oldinit, ensuring that its not already a Mapper-decorated init method - if oldinit is None or not hasattr(oldinit, '_oldinit'): - try: - init.__name__ = oldinit.__name__ - init.__doc__ = oldinit.__doc__ - except: - # cant set __name__ in py 2.3 ! - pass - init._oldinit = oldinit - self.class_.__init__ = init + def on_exception(class_, oldinit, instance, args, kwargs): + util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs) - _COMPILE_MUTEX.acquire() - try: - mapper_registry[self.class_key] = self - finally: - _COMPILE_MUTEX.release() + attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes) - if self.entity_name is None: - self.class_.c = self.c + self._class_state = self.class_._class_state + _mapper_registry[self] = True - def base_mapper(self): - """Return the ultimate base mapper in an inheritance chain.""" + self.class_._class_state.mappers[self.entity_name] = self - # TODO: calculate this at mapper setup time - if self.inherits is not None: - return self.inherits.base_mapper() - else: - return self + for ext in util.to_list(self.extension, []): + ext.instrument_class(self, self.class_) + + if self.entity_name is None: + self.class_.c = self.c def common_parent(self, other): """Return true if the given mapper shares a common inherited parent as this mapper.""" - return self.base_mapper() is other.base_mapper() + return self.base_mapper is other.base_mapper def isa(self, other): """Return True if the given mapper inherits from this mapper.""" m = other - while m is not self and m.inherits is not None: + while m is not self and m.inherits: m = m.inherits return m is self def iterate_to_root(self): m = self - while m is not None: + while m: yield m m = m.inherits @@ -727,7 +841,7 @@ class Mapper(object): all their inheriting mappers as well. 
To iterate through an entire hierarchy, use - ``mapper.base_mapper().polymorphic_iterator()``.""" + ``mapper.base_mapper.polymorphic_iterator()``.""" yield self for mapper in self._inheriting_mappers: @@ -744,7 +858,7 @@ class Mapper(object): self.add_property(key, value) def add_property(self, key, prop): - """Add an indiviual MapperProperty to this mapper. + """Add an individual MapperProperty to this mapper. If the mapper has not been compiled yet, just adds the property to the initial properties dictionary sent to the @@ -752,100 +866,15 @@ class Mapper(object): the given MapperProperty is compiled immediately. """ - self.properties[key] = prop - if self.__is_compiled: - # if we're compiled, make sure all the other mappers are compiled too - self._compile_all() - self._compile_property(key, prop, init=True) - - def _create_prop_from_column(self, column): - column = util.to_list(column) - if not sql.is_column(column[0]): - return None - mapped_column = [] - for c in column: - mc = self.mapped_table.corresponding_column(c, raiseerr=False) - if not mc: - raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c)) - mapped_column.append(mc) - return ColumnProperty(*mapped_column) - - def _adapt_inherited_property(self, key, prop): - if not self.concrete: - self._compile_property(key, prop, init=False, setparent=False) - # TODO: concrete properties dont adapt at all right now....will require copies of relations() etc. - - def _compile_property(self, key, prop, init=True, setparent=True): - """Add a ``MapperProperty`` to this or another ``Mapper``, - including configuration of the property. - - The properties' parent attribute will be set, and the property - will also be copied amongst the mappers which inherit from - this one. - - If the given `prop` is a ``Column`` or list of Columns, a - ``ColumnProperty`` will be created. - """ - - self.__log("_compile_property(%s, %s)" % (key, prop.__class__.__name__)) - - if not isinstance(prop, MapperProperty): - col = self._create_prop_from_column(prop) - if col is None: - raise exceptions.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop)) - prop = col - - self.__props[key] = prop - if setparent: - prop.set_parent(self) - - if isinstance(prop, ColumnProperty): - # relate the mapper's "select table" to the given ColumnProperty - col = self.select_table.corresponding_column(prop.columns[0], keys_ok=True, raiseerr=False) - # col might not be present! 
the selectable given to the mapper need not include "deferred" - # columns (included in zblog tests) - if col is None: - col = prop.columns[0] - self.columns[key] = col - for col in prop.columns: - proplist = self.columntoproperty.setdefault(col, []) - proplist.append(prop) - - if init: - prop.init(key, self) - - for mapper in self._inheriting_mappers: - mapper._adapt_inherited_property(key, prop) + self._init_properties[key] = prop + self._compile_property(key, prop, init=self.__props_init) def __str__(self): - return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.encodedname or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "") - - def _is_primary_mapper(self): - """Return True if this mapper is the primary mapper for its class key (class + entity_name).""" - return mapper_registry.get(self.class_key, None) is self + return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") def primary_mapper(self): """Return the primary mapper corresponding to this mapper's class key (class + entity_name).""" - return mapper_registry[self.class_key] - - def is_assigned(self, instance): - """Return True if this mapper handles the given instance. - - This is dependent not only on class assignment but the - optional `entity_name` parameter as well. - """ - - return instance.__class__ is self.class_ and getattr(instance, '_entity_name', None) == self.entity_name - - def _assign_entity_name(self, instance): - """Assign this Mapper's entity name to the given instance. - - Subsequent Mapper lookups for this instance will return the - primary mapper corresponding to this Mapper's class and entity - name. - """ - - instance._entity_name = self.entity_name + return self._class_state.mappers[self.entity_name] def get_session(self): """Return the contextual session provided by the mapper @@ -855,26 +884,23 @@ class Mapper(object): from the extension chain. """ - self.compile() - s = self.extension.get_session() - if s is EXT_PASS: - raise exceptions.InvalidRequestError("No contextual Session is established. Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.") - return s - - def has_eager(self): - """Return True if one of the properties attached to this - Mapper is eager loading. - """ + if 'get_session' in self.extension.methods: + s = self.extension.get_session() + if s is not EXT_CONTINUE: + return s - return len(self._eager_loaders) > 0 + raise exceptions.InvalidRequestError("No contextual Session is established.") def instances(self, cursor, session, *mappers, **kwargs): """Return a list of mapped instances corresponding to the rows in a given ResultProxy. + + DEPRECATED. 
""" import sqlalchemy.orm.query return sqlalchemy.orm.Query(self, session).instances(cursor, *mappers, **kwargs) + instances = util.deprecated(None, False)(instances) def identity_key_from_row(self, row): """Return an identity-map key for use in storing/retrieving an @@ -905,292 +931,298 @@ class Mapper(object): """ return self.identity_key_from_primary_key(self.primary_key_from_instance(instance)) + def _identity_key_from_state(self, state): + return self.identity_key_from_primary_key(self._primary_key_from_state(state)) + def primary_key_from_instance(self, instance): """Return the list of primary key values for the given instance. """ - return [self.get_attr_by_column(instance, column) for column in self.primary_key] + return [self._get_state_attr_by_column(instance._state, column) for column in self.primary_key] - def canload(self, instance): - """return true if this mapper is capable of loading the given instance""" - if self.polymorphic_on is not None: - return isinstance(instance, self.class_) + def _primary_key_from_state(self, state): + return [self._get_state_attr_by_column(state, column) for column in self.primary_key] + + def _canload(self, state): + if self.polymorphic_on: + return issubclass(state.class_, self.class_) else: - return instance.__class__ is self.class_ - - def _getpropbycolumn(self, column, raiseerror=True): + return state.class_ is self.class_ + + def _get_col_to_prop(self, column): try: - prop = self.columntoproperty[column] + return self._columntoproperty[column] except KeyError: - try: - prop = self.__props[column.key] - if not raiseerror: - return None - raise exceptions.InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) - except KeyError: - if not raiseerror: - return None - raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self))) - return prop[0] + prop = self.__props.get(column.key, None) + if prop: + raise exceptions.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) + else: + raise exceptions.UnmappedColumnError("No column %s.%s is configured on mapper %s..." 
% (column.table.name, column.name, str(self))) + + def _get_state_attr_by_column(self, state, column): + return self._get_col_to_prop(column).getattr(state, column) - def get_attr_by_column(self, obj, column, raiseerror=True): - """Return an instance attribute using a Column as the key.""" + def _set_state_attr_by_column(self, state, column, value): + return self._get_col_to_prop(column).setattr(state, value, column) - prop = self._getpropbycolumn(column, raiseerror) - if prop is None: - return NO_ATTRIBUTE - return prop.getattr(obj, column) + def _get_attr_by_column(self, obj, column): + return self._get_col_to_prop(column).getattr(obj._state, column) - def set_attr_by_column(self, obj, column, value): - """Set the value of an instance attribute using a Column as the key.""" + def _get_committed_attr_by_column(self, obj, column): + return self._get_col_to_prop(column).getcommitted(obj._state, column) - self.columntoproperty[column][0].setattr(obj, value, column) + def _set_attr_by_column(self, obj, column, value): + self._get_col_to_prop(column).setattr(obj._state, column, value) - def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False): + def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. This is called within the context of a UOWTransaction during a flush operation. - `save_obj` issues SQL statements not just for instances mapped + `_save_obj` issues SQL statements not just for instances mapped directly by this mapper, but for instances mapped by all inheriting mappers as well. This is to maintain proper insert ordering among a polymorphic chain of instances. Therefore - save_obj is typically called only on a *base mapper*, or a + _save_obj is typically called only on a *base mapper*, or a mapper which does not inherit from any other mapper. 
""" if self.__should_log_debug: - self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched")) + self.__log_debug("_save_obj() start, " + (single and "non-batched" or "batched")) - # if batch=false, call save_obj separately for each object + # if batch=false, call _save_obj separately for each object if not single and not self.batch: - for obj in objects: - self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) + for state in states: + self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return + # if session has a connection callable, + # organize individual states with the connection to use for insert/update if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(obj, connection_callable(self, obj)) for obj in objects] + tups = [(state, connection_callable(self, state.obj()), _state_has_identity(state)) for state in states] else: connection = uowtransaction.transaction.connection(self) - tups = [(obj, connection) for obj in objects] - + tups = [(state, connection, _state_has_identity(state)) for state in states] + if not postupdate: - for obj, connection in tups: - if not has_identity(obj): - for mapper in object_mapper(obj).iterate_to_root(): - mapper.extension.before_insert(mapper, connection, obj) + # call before_XXX extensions + for state, connection, has_identity in tups: + mapper = _state_mapper(state) + if not has_identity: + if 'before_insert' in mapper.extension.methods: + mapper.extension.before_insert(mapper, connection, state.obj()) else: - for mapper in object_mapper(obj).iterate_to_root(): - mapper.extension.before_update(mapper, connection, obj) + if 'before_update' in mapper.extension.methods: + mapper.extension.before_update(mapper, connection, state.obj()) - for obj, connection in tups: + for state, connection, has_identity in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), # and another instance with the same identity key already exists as persistent. convert to an # UPDATE if so. - mapper = object_mapper(obj) - instance_key = mapper.identity_key_from_instance(obj) - is_row_switch = not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map - if is_row_switch: - existing = uowtransaction.uow.identity_map[instance_key] + mapper = _state_mapper(state) + instance_key = mapper._identity_key_from_state(state) + if not postupdate and not has_identity and instance_key in uowtransaction.uow.identity_map: + existing = uowtransaction.uow.identity_map[instance_key]._state if not uowtransaction.is_deleted(existing): - raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.instance_str(obj), str(instance_key), mapperutil.instance_str(existing))) + raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) if self.__should_log_debug: - self.__log_debug("detected row switch for identity %s. 
will update %s, remove %s from transaction" % (instance_key, mapperutil.instance_str(obj), mapperutil.instance_str(existing))) - uowtransaction.unregister_object(existing) - if has_identity(obj): - if obj._instance_key != instance_key: - raise exceptions.FlushError("Can't change the identity of instance %s in session (existing identity: %s; new identity: %s)" % (mapperutil.instance_str(obj), obj._instance_key, instance_key)) + self.__log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing))) + uowtransaction.set_row_switch(existing) inserted_objects = util.Set() updated_objects = util.Set() table_to_mapper = {} - for mapper in self.base_mapper().polymorphic_iterator(): + for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: - table_to_mapper.setdefault(t, mapper) + table_to_mapper[t] = mapper - for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False): - # two lists to store parameters for each table/object pair located + for table in sqlutil.sort_tables(table_to_mapper.keys()): insert = [] update = [] - for obj, connection in tups: - mapper = object_mapper(obj) - if table not in mapper.tables or not mapper._has_pks(table): + for state, connection, has_identity in tups: + mapper = _state_mapper(state) + if table not in mapper._pks_by_table: continue - instance_key = mapper.identity_key_from_instance(obj) + pks = mapper._pks_by_table[table] + instance_key = mapper._identity_key_from_state(state) + if self.__should_log_debug: - self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.instance_str(obj), str(instance_key))) + self.__log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key))) - isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity(obj) + isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity params = {} + value_params = {} hasdata = False - for col in table.columns: - if col is mapper.version_id_col: - if not isinsert: - params[col._label] = mapper.get_attr_by_column(obj, col) - params[col.key] = params[col._label] + 1 - else: + + if isinsert: + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: params[col.key] = 1 - elif col in mapper.pks_by_table[table]: - # column is a primary key ? - if not isinsert: - # doing an UPDATE? put primary key values as "WHERE" parameters - # matching the bindparam we are creating below, i.e. "_" - params[col._label] = mapper.get_attr_by_column(obj, col) - else: - # doing an INSERT, primary key col ? - # if the primary key values are not populated, - # leave them out of the INSERT altogether, since PostGres doesn't want - # them to be present for SERIAL to take effect. 
A SQLEngine that uses - # explicit sequences will put them back in if they are needed - value = mapper.get_attr_by_column(obj, col) + elif col in pks: + value = mapper._get_state_attr_by_column(state, col) if value is not None: params[col.key] = value - elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): - if isinsert: + elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col): if self.__should_log_debug: self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key)) value = mapper.polymorphic_identity if col.default is None or value is not None: params[col.key] = value - else: - # column is not a primary key ? - if not isinsert: - # doing an UPDATE ? get the history for the attribute, with "passive" - # so as not to trigger any deferred loads. if there is a new - # value, add it to the bind parameters - if post_update_cols is not None and col not in post_update_cols: - continue - elif is_row_switch: - params[col.key] = self.get_attr_by_column(obj, col) - hasdata = True - continue - prop = mapper._getpropbycolumn(col, False) - if prop is None: - continue - history = prop.get_history(obj, passive=True) - if history: - a = history.added_items() - if len(a): - params[col.key] = prop.get_col_value(col, a[0]) + else: + value = mapper._get_state_attr_by_column(state, col) + if col.default is None or value is not None: + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + insert.append((state, params, mapper, connection, value_params)) + else: + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col._label] = mapper._get_state_attr_by_column(state, col) + params[col.key] = params[col._label] + 1 + for prop in mapper._columntoproperty.values(): + (added, unchanged, deleted) = attributes.get_history(state, prop.key, passive=True) + if added: hasdata = True + elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col): + pass else: - # doing an INSERT, non primary key col ? - # add the attribute's value to the - # bind parameters, unless its None and the column has a - # default. if its None and theres no default, we still might - # not want to put it in the col list but SQLIte doesnt seem to like that - # if theres no columns at all - value = mapper.get_attr_by_column(obj, col, False) - if value is NO_ATTRIBUTE: + if post_update_cols is not None and col not in post_update_cols: + if col in pks: + params[col._label] = mapper._get_state_attr_by_column(state, col) continue - if col.default is None or value is not None: - params[col.key] = value - if not isinsert: + prop = mapper._columntoproperty[col] + (added, unchanged, deleted) = attributes.get_history(state, prop.key, passive=True) + if added: + if isinstance(added[0], sql.ClauseElement): + value_params[col] = added[0] + else: + params[col.key] = prop.get_col_value(col, added[0]) + if col in pks: + if deleted: + params[col._label] = deleted[0] + else: + # row switch logic can reach us here + params[col._label] = added[0] + hasdata = True + elif col in pks: + params[col._label] = mapper._get_state_attr_by_column(state, col) if hasdata: - # if none of the attributes changed, dont even - # add the row to be updated. 
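For reference, a minimal sketch (not part of the patch) of a mapping that exercises the version_id_col branches assembled above; table and class names are illustrative:

    from sqlalchemy import MetaData, Table, Column, Integer, String
    from sqlalchemy.orm import mapper

    metadata = MetaData()
    widgets = Table('widgets', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)),
        Column('version_id', Integer, nullable=False))

    class Widget(object):
        pass

    # the flush code writes version_id=1 on INSERT; each UPDATE adds
    # "WHERE version_id = :current" and sets version_id to :current + 1,
    # raising ConcurrentModificationError when the rowcount doesn't match
    mapper(Widget, widgets, version_id_col=widgets.c.version_id)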
- update.append((obj, params, mapper, connection)) - else: - insert.append((obj, params, mapper, connection)) + update.append((state, params, mapper, connection, value_params)) - if len(update): + if update: mapper = table_to_mapper[table] clause = sql.and_() - for col in mapper.pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True)) - if mapper.version_id_col is not None: - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True)) + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) + + if mapper.version_id_col and table.c.contains_column(mapper.version_id_col): + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type)) + statement = table.update(clause) - rows = 0 - supports_sane_rowcount = True + pks = mapper._pks_by_table[table] def comparator(a, b): - for col in mapper.pks_by_table[table]: + for col in pks: x = cmp(a[1][col._label],b[1][col._label]) if x != 0: return x return 0 update.sort(comparator) + + rows = 0 for rec in update: - (obj, params, mapper, connection) = rec - c = connection.execute(statement, params) - mapper._postfetch(connection, table, obj, c, c.last_updated_params()) + (state, params, mapper, connection, value_params) = rec + c = connection.execute(statement.values(value_params), params) + mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params) - updated_objects.add((obj, connection)) + # testlib.pragma exempt:__hash__ + updated_objects.add((state, connection)) rows += c.rowcount if c.supports_sane_rowcount() and rows != len(update): raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) - if len(insert): + if insert: statement = table.insert() def comparator(a, b): - return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order) + return cmp(a[0].insert_order, b[0].insert_order) insert.sort(comparator) for rec in insert: - (obj, params, mapper, connection) = rec - c = connection.execute(statement, params) + (state, params, mapper, connection, value_params) = rec + c = connection.execute(statement.values(value_params), params) primary_key = c.last_inserted_ids() + if primary_key is not None: - i = 0 - for col in mapper.pks_by_table[table]: - if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i: - mapper.set_attr_by_column(obj, col, primary_key[i]) - i+=1 - mapper._postfetch(connection, table, obj, c, c.last_inserted_params()) + # set primary key attributes + for i, col in enumerate(mapper._pks_by_table[table]): + if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i: + mapper._set_state_attr_by_column(state, col, primary_key[i]) + mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params) # synchronize newly inserted ids from one table to the next # TODO: this fires off more than needed, try to organize syncrules # per table - def sync(mapper): - inherit = mapper.inherits - if inherit is not None: - sync(inherit) - if mapper._synchronizer is not None: - mapper._synchronizer.execute(obj, obj) - sync(mapper) - - inserted_objects.add((obj, connection)) + for m in util.reversed(list(mapper.iterate_to_root())): + if m.__inherits_equated_pairs: + m.__synchronize_inherited(state) + + # testlib.pragma exempt:__hash__ + 
inserted_objects.add((state, connection)) + if not postupdate: - for obj, connection in inserted_objects: - for mapper in object_mapper(obj).iterate_to_root(): - mapper.extension.after_insert(mapper, connection, obj) - for obj, connection in updated_objects: - for mapper in object_mapper(obj).iterate_to_root(): - mapper.extension.after_update(mapper, connection, obj) - - def _postfetch(self, connection, table, obj, resultproxy, params): - """After an ``INSERT`` or ``UPDATE``, ask the returned result - if ``PassiveDefaults`` fired off on the database side which - need to be post-fetched, **or** if pre-exec defaults like - ``ColumnDefaults`` were fired off and should be populated into - the instance. this is only for non-primary key columns. + # call after_XXX extensions + for state, connection, has_identity in tups: + mapper = _state_mapper(state) + if not has_identity: + if 'after_insert' in mapper.extension.methods: + mapper.extension.after_insert(mapper, connection, state.obj()) + else: + if 'after_update' in mapper.extension.methods: + mapper.extension.after_update(mapper, connection, state.obj()) + + def __synchronize_inherited(self, state): + sync.populate(state, self, state, self, self.__inherits_equated_pairs) + + def __postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params): + """After an ``INSERT`` or ``UPDATE``, assemble newly generated + values on an instance. For columns which are marked as being generated + on the database side, set up a group-based "deferred" loader + which will populate those attributes in one query when next accessed. """ - if resultproxy.lastrow_has_defaults(): - clause = sql.and_() - for p in self.pks_by_table[table]: - clause.clauses.append(p == self.get_attr_by_column(obj, p)) - row = connection.execute(table.select(clause), None).fetchone() - for c in table.c: - if self.get_attr_by_column(obj, c, False) is None: - self.set_attr_by_column(obj, c, row[c]) - else: - for c in table.c: - if c.primary_key or not c.key in params: - continue - v = self.get_attr_by_column(obj, c, False) - if v is NO_ATTRIBUTE: - continue - elif v != params.get_original(c.key): - self.set_attr_by_column(obj, c, params.get_original(c.key)) + postfetch_cols = resultproxy.postfetch_cols() + generated_cols = list(resultproxy.prefetch_cols()) - def delete_obj(self, objects, uowtransaction): + if self.polymorphic_on: + po = table.corresponding_column(self.polymorphic_on) + if po: + generated_cols.append(po) + if self.version_id_col: + generated_cols.append(self.version_id_col) + + for c in generated_cols: + if c.key in params: + self._set_state_attr_by_column(state, c, params[c.key]) + + deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]] + + if deferred_props: + if self.eager_defaults: + _instance_key = self._identity_key_from_state(state) + state.dict['_instance_key'] = _instance_key + uowtransaction.session.query(self)._get(_instance_key, refresh_instance=state, only_load_props=deferred_props) + else: + _expire_state(state, deferred_props) + + def _delete_obj(self, states, uowtransaction): """Issue ``DELETE`` statements for a list of objects. 
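__postfetch() above arranges for database-generated column values to be loaded lazily as a group, or immediately when eager_defaults is in effect; a rough sketch of a mapping that produces post-fetched columns (eager_defaults as a mapper() keyword is an assumption here, and all names are illustrative):

    from sqlalchemy import MetaData, Table, Column, Integer, String, DateTime, PassiveDefault
    from sqlalchemy.orm import mapper

    metadata = MetaData()
    documents = Table('documents', metadata,
        Column('id', Integer, primary_key=True),
        Column('title', String(100)),
        # generated by the database; reported via ResultProxy.postfetch_cols()
        Column('updated', DateTime, PassiveDefault('CURRENT_TIMESTAMP')))

    class Document(object):
        pass

    # with eager_defaults enabled the mapper re-selects 'updated' right after
    # the INSERT/UPDATE; otherwise the attribute is expired and loaded on
    # first access
    mapper(Document, documents, eager_defaults=True)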
This is called within the context of a UOWTransaction during a @@ -1198,76 +1230,69 @@ class Mapper(object): """ if self.__should_log_debug: - self.__log_debug("delete_obj() start") + self.__log_debug("_delete_obj() start") if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(obj, connection_callable(self, obj)) for obj in objects] + tups = [(state, connection_callable(self, state.obj())) for state in states] else: connection = uowtransaction.transaction.connection(self) - tups = [(obj, connection) for obj in objects] + tups = [(state, connection) for state in states] + + for (state, connection) in tups: + mapper = _state_mapper(state) + if 'before_delete' in mapper.extension.methods: + mapper.extension.before_delete(mapper, connection, state.obj()) - for (obj, connection) in tups: - for mapper in object_mapper(obj).iterate_to_root(): - mapper.extension.before_delete(mapper, connection, obj) - deleted_objects = util.Set() table_to_mapper = {} - for mapper in self.base_mapper().polymorphic_iterator(): + for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: - table_to_mapper.setdefault(t, mapper) + table_to_mapper[t] = mapper - for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True): + for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True): delete = {} - for (obj, connection) in tups: - mapper = object_mapper(obj) - if table not in mapper.tables or not mapper._has_pks(table): + for (state, connection) in tups: + mapper = _state_mapper(state) + if table not in mapper._pks_by_table: continue params = {} - if not hasattr(obj, '_instance_key'): + if not _state_has_identity(state): continue else: delete.setdefault(connection, []).append(params) - for col in mapper.pks_by_table[table]: - params[col.key] = mapper.get_attr_by_column(obj, col) - if mapper.version_id_col is not None: - params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col) - deleted_objects.add((obj, connection)) + for col in mapper._pks_by_table[table]: + params[col.key] = mapper._get_state_attr_by_column(state, col) + if mapper.version_id_col and table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col) + # testlib.pragma exempt:__hash__ + deleted_objects.add((state, connection)) for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] def comparator(a, b): - for col in mapper.pks_by_table[table]: + for col in mapper._pks_by_table[table]: x = cmp(a[col.key],b[col.key]) if x != 0: return x return 0 del_objects.sort(comparator) clause = sql.and_() - for col in mapper.pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True)) - if mapper.version_id_col is not None: - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True)) + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col.key, type_=col.type)) + if mapper.version_id_col and table.c.contains_column(mapper.version_id_col): + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type)) statement = table.delete(clause) c = connection.execute(statement, del_objects) - if c.supports_sane_rowcount() and c.rowcount != len(del_objects): - raise 
exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.rowcount, len(delete))) + if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): + raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) - for obj, connection in deleted_objects: - for mapper in object_mapper(obj).iterate_to_root(): - mapper.extension.after_delete(mapper, connection, obj) + for state, connection in deleted_objects: + mapper = _state_mapper(state) + if 'after_delete' in mapper.extension.methods: + mapper.extension.after_delete(mapper, connection, state.obj()) - def _has_pks(self, table): - try: - for k in self.pks_by_table[table]: - if not self.columntoproperty.has_key(k): - return False - else: - return True - except KeyError: - return False - - def register_dependencies(self, uowcommit, *args, **kwargs): + def _register_dependencies(self, uowcommit): """Register ``DependencyProcessor`` instances with a ``unitofwork.UOWTransaction``. @@ -1276,336 +1301,327 @@ class Mapper(object): """ for prop in self.__props.values(): - prop.register_dependencies(uowcommit, *args, **kwargs) + prop.register_dependencies(uowcommit) + for dep in self._dependency_processors: + dep.register_dependencies(uowcommit) - def cascade_iterator(self, type, object, recursive=None, halt_on=None): - """Iterate each element in an object graph, for all relations - taht meet the given cascade rule. + def cascade_iterator(self, type_, state, halt_on=None): + """Iterate each element and its mapper in an object graph, + for all relations that meet the given cascade rule. - type + type\_ The name of the cascade rule (i.e. save-update, delete, etc.) - object - The lead object instance. child items will be processed per + state + The lead InstanceState. child items will be processed per the relations defined for this object's mapper. - recursive - Used by the function for internal context during recursive - calls, leave as None. + the return value are object instances; this provides a strong + reference so that they don't fall out of scope immediately. """ - if recursive is None: - recursive=util.Set() - for prop in self.__props.values(): - for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on): - yield c + visited_instances = util.IdentitySet() + visitables = [(self.__props.itervalues(), 'property', state)] - def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None): - """Execute a callable for each element in an object graph, for - all relations that meet the given cascade rule. - - type - The name of the cascade rule (i.e. save-update, delete, etc.) - - object - The lead object instance. child items will be processed per - the relations defined for this object's mapper. - - callable\_ - The callable function. - - recursive - Used by the function for internal context during recursive - calls, leave as None. - - """ - - if recursive is None: - recursive=util.Set() - for prop in self.__props.values(): - prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on) - - def get_select_mapper(self): - """Return the mapper used for issuing selects. - - This mapper is the same mapper as `self` unless the - select_table argument was specified for this mapper. 
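The flush paths above now invoke an extension hook only when it is actually defined (the "'before_insert' in mapper.extension.methods" style of check); a minimal sketch of such an extension follows, assuming EXT_CONTINUE is importable from sqlalchemy.orm at this revision (earlier releases spelled it EXT_PASS); class names are illustrative:

    import datetime
    from sqlalchemy.orm import MapperExtension, EXT_CONTINUE

    class TimestampExtension(MapperExtension):
        # called once per instance, before its INSERT parameters are assembled
        def before_insert(self, mapper, connection, instance):
            instance.created_at = datetime.datetime.now()
            return EXT_CONTINUE

        # called once per instance, before its UPDATE parameters are assembled
        def before_update(self, mapper, connection, instance):
            instance.updated_at = datetime.datetime.now()
            return EXT_CONTINUE

        # called before the corresponding DELETE is issued
        def before_delete(self, mapper, connection, instance):
            return EXT_CONTINUE

    # attached at mapping time:
    #   mapper(Widget, widgets, extension=TimestampExtension())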
- """ - - return self.__surrogate_mapper or self - - def _instance(self, context, row, result = None, skip_polymorphic=False): - """Pull an object instance from the given row and append it to - the given result list. - - If the instance already exists in the given identity map, its - not added. In either case, execute all the property loaders - on the instance to also process extra information in the row. - """ - - # apply ExtensionOptions applied to the Query to this mapper, - # but only if our mapper matches. - # TODO: what if our mapper inherits from the mapper (i.e. as in a polymorphic load?) - if context.mapper is self: - extension = context.extension - else: + while visitables: + iterator,item_type,parent_state = visitables[-1] + try: + if item_type == 'property': + prop = iterator.next() + visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None)) + elif item_type == 'mapper': + instance, instance_mapper, corresponding_state = iterator.next() + yield (instance, instance_mapper) + visitables.append((instance_mapper.__props.itervalues(), 'property', corresponding_state)) + except StopIteration: + visitables.pop() + + def _instance(self, context, row, result=None, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None): + if not extension: extension = self.extension - ret = extension.translate_row(self, context, row) - if ret is not EXT_PASS: - row = ret - - if not skip_polymorphic and self.polymorphic_on is not None: + if 'translate_row' in extension.methods: + ret = extension.translate_row(self, context, row) + if ret is not EXT_CONTINUE: + row = ret + + if polymorphic_from: + # if we are called from a base mapper doing a polymorphic load, figure out what tables, + # if any, will need to be "post-fetched" based on the tables present in the row, + # or from the options set up on the query + if ('polymorphic_fetch', self) not in context.attributes: + if self in context.query._with_polymorphic: + context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, []) + else: + context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [t for t in self.tables if t not in polymorphic_from.tables]) + + elif not refresh_instance and self.polymorphic_on: discriminator = row[self.polymorphic_on] if discriminator is not None: - mapper = self.polymorphic_map[discriminator] + try: + mapper = self.polymorphic_map[discriminator] + except KeyError: + raise exceptions.AssertionError("No such polymorphic_identity %r is defined" % discriminator) if mapper is not self: - if ('polymorphic_fetch', mapper) not in context.attributes: - context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables]) - row = self.translate_row(mapper, row) - return mapper._instance(context, row, result=result, skip_polymorphic=True) - - # look in main identity map. if its there, we dont do anything to it, - # including modifying any of its related items lists, as its already - # been exposed to being modified by the application. 
- - identitykey = self.identity_key_from_row(row) - populate_existing = context.populate_existing or self.always_refresh - if context.session.has_key(identitykey): - instance = context.session._get(identitykey) + return mapper._instance(context, row, result=result, polymorphic_from=self) + + # determine identity key + if refresh_instance: + try: + identitykey = refresh_instance.dict['_instance_key'] + except KeyError: + # super-rare condition; a refresh is being called + # on a non-instance-key instance; this is meant to only + # occur wihtin a flush() + identitykey = self._identity_key_from_state(refresh_instance) + else: + identitykey = self.identity_key_from_row(row) + + session_identity_map = context.session.identity_map + + if identitykey in session_identity_map: + instance = session_identity_map[identitykey] + state = instance._state + if self.__should_log_debug: - self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) - isnew = False - if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: - raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) - - if populate_existing or context.session.is_expired(instance, unexpire=True): - if not context.identity_map.has_key(identitykey): - context.identity_map[identitykey] = instance - isnew = True - if extension.populate_instance(self, context, row, instance, **{'instancekey':identitykey, 'isnew':isnew}) is EXT_PASS: - self.populate_instance(context, instance, row, **{'instancekey':identitykey, 'isnew':isnew}) - if extension.append_result(self, context, row, instance, result, **{'instancekey':identitykey, 'isnew':isnew}) is EXT_PASS: - if result is not None: - result.append(instance) - return instance + self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey))) + + isnew = state.runid != context.runid + currentload = not isnew + + if not currentload and context.version_check and self.version_id_col and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: + raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) + elif refresh_instance: + # out of band refresh_instance detected (i.e. its not in the session.identity_map) + # honor it anyway. this can happen if a _get() occurs within save_obj(), such as + # when eager_defaults is True. + state = refresh_instance + instance = state.obj() + isnew = state.runid != context.runid + currentload = True else: if self.__should_log_debug: self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) - # look in result-local identitymap for it. - exists = identitykey in context.identity_map - if not exists: + if self.allow_null_pks: - # check if *all* primary key cols in the result are None - this indicates - # an instance of the object is not present in the row. for x in identitykey[1]: if x is not None: break else: return None else: - # otherwise, check if *any* primary key cols in the result are None - this indicates - # an instance of the object is not present in the row. 
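The discriminator dispatch near the top of _instance() is driven by the usual polymorphic mapper configuration; a compressed sketch (illustrative names), where an unrecognized discriminator value now raises the "No such polymorphic_identity" assertion rather than a bare KeyError:

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey
    from sqlalchemy.orm import mapper

    metadata = MetaData()
    people = Table('people', metadata,
        Column('id', Integer, primary_key=True),
        Column('type', String(30)),           # discriminator read by _instance()
        Column('name', String(50)))
    engineers = Table('engineers', metadata,
        Column('id', Integer, ForeignKey('people.id'), primary_key=True),
        Column('language', String(50)))

    class Person(object): pass
    class Engineer(Person): pass

    mapper(Person, people,
           polymorphic_on=people.c.type, polymorphic_identity='person')
    mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')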
if None in identitykey[1]: return None - - # plugin point - instance = extension.create_instance(self, context, row, self.class_) - if instance is EXT_PASS: - instance = self._create_instance(context.session) - else: - instance._entity_name = self.entity_name - if self.__should_log_debug: - self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) - context.identity_map[identitykey] = instance isnew = True - else: - instance = context.identity_map[identitykey] - isnew = False - - # call further mapper properties on the row, to pull further - # instances from the row and possibly populate this item. - flags = {'instancekey':identitykey, 'isnew':isnew} - if extension.populate_instance(self, context, row, instance, **flags) is EXT_PASS: - self.populate_instance(context, instance, row, **flags) - if extension.append_result(self, context, row, instance, result, **flags) is EXT_PASS: - if result is not None: - result.append(instance) - return instance + currentload = True - def _create_instance(self, session): - obj = self.class_.__new__(self.class_) - obj._entity_name = self.entity_name - - # this gets the AttributeManager to do some pre-initialization, - # in order to save on KeyErrors later on - attribute_manager.init_attr(obj) + if 'create_instance' in extension.methods: + instance = extension.create_instance(self, context, row, self.class_) + if instance is EXT_CONTINUE: + instance = attributes.new_instance(self.class_) + else: + attributes.manage(instance) + else: + instance = attributes.new_instance(self.class_) - return obj + if self.__should_log_debug: + self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) - def _deferred_inheritance_condition(self, needs_tables): - cond = self.inherit_condition + state = instance._state + instance._entity_name = self.entity_name + instance._instance_key = identitykey + instance._sa_session_id = context.session.hash_key + session_identity_map[identitykey] = instance - param_names = [] - def visit_binary(binary): - leftcol = binary.left - rightcol = binary.right - if leftcol is None or rightcol is None: - return - if leftcol.table not in needs_tables: - binary.left = sql.bindparam(leftcol.name, None, type_=binary.right.type, unique=True) - param_names.append(leftcol) - elif rightcol not in needs_tables: - binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True) - param_names.append(rightcol) - cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True) - return cond, param_names + if currentload or context.populate_existing or self.always_refresh: + if isnew: + state.runid = context.runid + context.progress.add(state) - def translate_row(self, tomapper, row): - """Translate the column keys of a row into a new or proxied - row that can be understood by another mapper. 
+ if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) + + else: + # populate attributes on non-loading instances which have been expired + # TODO: also support deferred attributes here [ticket:870] + if state.expired_attributes: + if state in context.partials: + isnew = False + attrs = context.partials[state] + else: + isnew = True + attrs = state.expired_attributes.intersection(state.unmodified) + context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs - This can be used in conjunction with populate_instance to - populate an instance using an alternate mapper. - """ + if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew) + + if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE): + result.append(instance) - newrow = util.DictDecorator(row) - for c in tomapper.mapped_table.c: - c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=False) - if c2 and row.has_key(c2): - newrow[c] = row[c2] - return newrow + return instance - def populate_instance(self, selectcontext, instance, row, ispostselect=None, **flags): + def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags): """populate an instance from a result row.""" - selectcontext.stack.push_mapper(self) - populators = selectcontext.attributes.get(('instance_populators', self, selectcontext.stack.snapshot(), ispostselect), None) - if populators is None: - populators = [] + snapshot = selectcontext.path + (self,) + # retrieve a set of "row population" functions derived from the MapperProperties attached + # to this Mapper. These are keyed in the select context based primarily off the + # "snapshot" of the stack, which represents a path from the lead mapper in the query to this one, + # including relation() names. the key also includes "self", and allows us to distinguish between + # other mappers within our inheritance hierarchy + (new_populators, existing_populators) = selectcontext.attributes.get(('populators', self, snapshot, ispostselect), (None, None)) + if new_populators is None: + # no populators; therefore this is the first time we are receiving a row for + # this result set. issue create_row_processor() on all MapperProperty objects + # and cache in the select context. 
+ new_populators = [] + existing_populators = [] post_processors = [] for prop in self.__props.values(): - (pop, post_proc) = prop.create_row_processor(selectcontext, self, row) - if pop is not None: - populators.append(pop) - if post_proc is not None: + (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row) + if newpop: + new_populators.append((prop.key, newpop)) + if existingpop: + existing_populators.append((prop.key, existingpop)) + if post_proc: post_processors.append(post_proc) - + + # install a post processor for immediate post-load of joined-table inheriting mappers poly_select_loader = self._get_poly_select_loader(selectcontext, row) - if poly_select_loader is not None: + if poly_select_loader: post_processors.append(poly_select_loader) - - selectcontext.attributes[('instance_populators', self, selectcontext.stack.snapshot(), ispostselect)] = populators + + selectcontext.attributes[('populators', self, snapshot, ispostselect)] = (new_populators, existing_populators) selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors - for p in populators: - p(instance, row, ispostselect=ispostselect, **flags) - - selectcontext.stack.pop() - + if isnew or ispostselect: + populators = new_populators + else: + populators = existing_populators + + if only_load_props: + populators = [p for p in populators if p[0] in only_load_props] + + for (key, populator) in populators: + selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags) + if self.non_primary: - selectcontext.attributes[('populating_mapper', instance)] = self - - def _post_instance(self, selectcontext, instance): + selectcontext.attributes[('populating_mapper', instance._state)] = self + + def _post_instance(self, selectcontext, state, **kwargs): post_processors = selectcontext.attributes[('post_processors', self, None)] for p in post_processors: - p(instance) + p(state.obj(), **kwargs) def _get_poly_select_loader(self, selectcontext, row): - # 'select' or 'union'+col not present - (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None)) - if hosted_mapper is None or len(needs_tables)==0 or hosted_mapper.polymorphic_fetch == 'deferred': - return - - cond, param_names = self._deferred_inheritance_condition(needs_tables) - statement = sql.select(needs_tables, cond, use_labels=True) - def post_execute(instance, **flags): - self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance)) - - identitykey = self.identity_key_from_instance(instance) + """set up attribute loaders for 'select' and 'deferred' polymorphic loading. - params = {} - for c in param_names: - params[c.name] = self.get_attr_by_column(instance, c) - row = selectcontext.session.connection(self).execute(statement, **params).fetchone() - self.populate_instance(selectcontext, instance, row, **{'isnew':False, 'instancekey':identitykey, 'ispostselect':True}) - - return post_execute - -Mapper.logger = logging.class_logger(Mapper) + this loading uses a second SELECT statement to load additional tables, + either immediately after loading the main table or via a deferred attribute trigger. 
+ """ + (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None)) + if hosted_mapper is None or not needs_tables: + return + cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables) + statement = sql.select(needs_tables, cond, use_labels=True) -class ClassKey(object): - """Key a class and an entity name to a mapper, via the mapper_registry.""" + if hosted_mapper.polymorphic_fetch == 'select': + def post_execute(instance, **flags): + if self.__should_log_debug: + self.__log_debug("Post query loading instance " + instance_str(instance)) - __metaclass__ = util.ArgSingleton + identitykey = self.identity_key_from_instance(instance) + + only_load_props = flags.get('only_load_props', None) - def __init__(self, class_, entity_name): - self.class_ = class_ - self.entity_name = entity_name + params = {} + for c, bind in param_names: + params[bind] = self._get_attr_by_column(instance, c) + row = selectcontext.session.connection(self).execute(statement, params).fetchone() + self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True, only_load_props=only_load_props) + return post_execute + elif hosted_mapper.polymorphic_fetch == 'deferred': + from sqlalchemy.orm.strategies import DeferredColumnLoader + + def post_execute(instance, **flags): + def create_statement(instance): + params = {} + for (c, bind) in param_names: + # use the "committed" (database) version to get query column values + params[bind] = self._get_committed_attr_by_column(instance, c) + return (statement, params) + + props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__] + keys = [p.key for p in props] + + only_load_props = flags.get('only_load_props', None) + if only_load_props: + keys = util.Set(keys).difference(only_load_props) + props = [p for p in props if p.key in only_load_props] + + for prop in props: + strategy = prop._get_strategy(DeferredColumnLoader) + instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement)) + return post_execute + else: + return None - def __hash__(self): - return hash((self.class_, self.entity_name)) + def _deferred_inheritance_condition(self, base_mapper, needs_tables): + base_mapper = base_mapper.primary_mapper() - def __eq__(self, other): - return self is other + def visit_binary(binary): + leftcol = binary.left + rightcol = binary.right + if leftcol is None or rightcol is None: + return + if leftcol.table not in needs_tables: + binary.left = sql.bindparam(None, None, type_=binary.right.type) + param_names.append((leftcol, binary.left)) + elif rightcol not in needs_tables: + binary.right = sql.bindparam(None, None, type_=binary.right.type) + param_names.append((rightcol, binary.right)) - def __repr__(self): - return "ClassKey(%s, %s)" % (repr(self.class_), repr(self.entity_name)) + allconds = [] + param_names = [] - def dispose(self): - type(self).dispose_static(self.class_, self.entity_name) + for mapper in self.iterate_to_root(): + if mapper is base_mapper: + break + allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary)) -def has_identity(object): - return hasattr(object, '_instance_key') + return sql.and_(*allconds), param_names -def has_mapper(object): - """Return True if the given object has had a mapper association - set up, either through loading, or via insertion in a session. 
- """ +Mapper.logger = logging.class_logger(Mapper) - return hasattr(object, '_entity_name') -def object_mapper(object, entity_name=None, raiseerror=True): - """Given an object, return the primary Mapper associated with the object instance. - - object - The object instance. - - entity_name - Entity name of the mapper to retrieve, if the given instance is - transient. Otherwise uses the entity name already associated - with the instance. - - raiseerror - Defaults to True: raise an ``InvalidRequestError`` if no mapper can - be located. If False, return None. - - """ +object_session = None - try: - mapper = mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', entity_name))] - except (KeyError, AttributeError): - if raiseerror: - raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', entity_name))) - else: - return None - return mapper.compile() +def _load_scalar_attributes(instance, attribute_names): + mapper = object_mapper(instance) + global object_session + if not object_session: + from sqlalchemy.orm.session import object_session + session = object_session(instance) + if not session: + try: + session = mapper.get_session() + except exceptions.InvalidRequestError: + raise exceptions.UnboundExecutionError("Instance %s is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed" % (instance.__class__)) + + state = instance._state + if '_instance_key' in state.dict: + identity_key = state.dict['_instance_key'] + shouldraise = True + else: + # if instance is pending, a refresh operation may not complete (even if PK attributes are assigned) + shouldraise = False + identity_key = mapper._identity_key_from_state(state) -def class_mapper(class_, entity_name=None, compile=True): - """Given a class and optional entity_name, return the primary Mapper associated with the key. - - If no mapper can be located, raises ``InvalidRequestError``. - """ + if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None and shouldraise: + raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance)) - try: - mapper = mapper_registry[ClassKey(class_, entity_name)] - except (KeyError, AttributeError): - raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name)) - if compile: - return mapper.compile() - else: - return mapper diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 6ce9fd7069..33a0ff4326 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -1,25 +1,28 @@ # properties.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Defines a set of mapper.MapperProperty objects, including basic -column properties as well as relationships. The objects rely upon the -LoaderStrategy objects in the strategies.py module to handle load -operations. PropertyLoader also relies upon the dependency.py module -to handle flush-time dependency sorting and processing. +"""MapperProperty implementations. + +This is a private module which defines the behavior of +invidual ORM-mapped attributes. 
""" -from sqlalchemy import sql, schema, util, exceptions, sql_util, logging -from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency +from sqlalchemy import sql, schema, util, exceptions, logging +from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, find_columns +from sqlalchemy.sql import visitors, operators, ColumnElement +from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper from sqlalchemy.orm import session as sessionlib -from sqlalchemy.orm import util as mapperutil -import operator -from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator +from sqlalchemy.orm.mapper import _class_to_mapper +from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses +from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY from sqlalchemy.exceptions import ArgumentError -__all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef'] +__all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty', + 'ComparableProperty', 'PropertyLoader', 'BackRef') + class ColumnProperty(StrategizedProperty): """Describes an object attribute that corresponds to a table column.""" @@ -35,33 +38,44 @@ class ColumnProperty(StrategizedProperty): self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) self.comparator = ColumnProperty.ColumnComparator(self) - # sanity check - for col in columns: - if not hasattr(col, 'name'): - if hasattr(col, 'label'): - raise ArgumentError('ColumnProperties must be named for the mapper to work with them. Try .label() to fix this') - raise ArgumentError('%r is not a valid candidate for ColumnProperty' % col) - - def create_strategy(self): if self.deferred: - return strategies.DeferredColumnLoader(self) + self.strategy_class = strategies.DeferredColumnLoader else: - return strategies.ColumnLoader(self) - + self.strategy_class = strategies.ColumnLoader + # sanity check + for col in columns: + if not isinstance(col, ColumnElement): + raise ArgumentError('column_property() must be given a ColumnElement as its argument. Try .label() or .as_scalar() for Selectables to fix this.') + + def do_init(self): + super(ColumnProperty, self).do_init() + if len(self.columns) > 1 and self.parent.primary_key.issuperset(self.columns): + util.warn( + ("On mapper %s, primary key column '%s' is being combined " + "with distinct primary key column '%s' in attribute '%s'. 
" + "Use explicit properties to give each column its own mapped " + "attribute name.") % (str(self.parent), str(self.columns[1]), + str(self.columns[0]), self.key)) + def copy(self): return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) - - def getattr(self, object, column): - return getattr(object, self.key) - def setattr(self, object, value, column): - setattr(object, self.key, value) + def getattr(self, state, column): + return getattr(state.class_, self.key).impl.get(state) - def get_history(self, obj, passive=False): - return sessionlib.attribute_manager.get_history(obj, self.key, passive=passive) + def getcommitted(self, state, column): + return getattr(state.class_, self.key).impl.get_committed_value(state) - def merge(self, session, source, dest, _recursive): - setattr(dest, self.key, getattr(source, self.key, None)) + def setattr(self, state, value, column): + getattr(state.class_, self.key).impl.set(state, value, None) + + def merge(self, session, source, dest, dont_load, _recursive): + value = attributes.get_as_list(source._state, self.key, passive=True) + if value: + setattr(dest, self.key, value[0]) + else: + # TODO: lazy callable should merge to the new instance + dest._state.expire_attributes([self.key]) def get_col_value(self, column, value): return value @@ -69,44 +83,52 @@ class ColumnProperty(StrategizedProperty): class ColumnComparator(PropComparator): def clause_element(self): return self.prop.columns[0] - - def operate(self, op, other): - return op(self.prop.columns[0], other) - def reverse_operate(self, op, other): + def operate(self, op, *other, **kwargs): + return op(self.prop.columns[0], *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): col = self.prop.columns[0] - return op(col._bind_param(other), col) + return op(col._bind_param(other), col, **kwargs) - ColumnProperty.logger = logging.class_logger(ColumnProperty) -mapper.ColumnProperty = ColumnProperty - class CompositeProperty(ColumnProperty): """subclasses ColumnProperty to provide composite type support.""" - + def __init__(self, class_, *columns, **kwargs): super(CompositeProperty, self).__init__(*columns, **kwargs) self.composite_class = class_ - self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator(self)) - + self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self) + + def do_init(self): + super(ColumnProperty, self).do_init() + # TODO: similar PK check as ColumnProperty does ? 
+ def copy(self): return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) - def getattr(self, object, column): - obj = getattr(object, self.key) + def getattr(self, state, column): + obj = getattr(state.class_, self.key).impl.get(state) + return self.get_col_value(column, obj) + + def getcommitted(self, state, column): + obj = getattr(state.class_, self.key).impl.get_committed_value(state) return self.get_col_value(column, obj) - def setattr(self, object, value, column): - obj = getattr(object, self.key, None) + def setattr(self, state, value, column): + # TODO: test coverage for this method + obj = getattr(state.class_, self.key).impl.get(state) if obj is None: obj = self.composite_class(*[None for c in self.columns]) - for a, b in zip(self.columns, value.__colset__()): + getattr(state.class_, self.key).impl.set(state, obj, None) + + for a, b in zip(self.columns, value.__composite_values__()): if a is column: setattr(obj, b, value) def get_col_value(self, column, value): - for a, b in zip(self.columns, value.__colset__()): + for a, b in zip(self.columns, value.__composite_values__()): if a is column: return b @@ -115,17 +137,80 @@ class CompositeProperty(ColumnProperty): if other is None: return sql.and_(*[a==None for a in self.prop.columns]) else: - return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())]) + return sql.and_(*[a==b for a, b in + zip(self.prop.columns, + other.__composite_values__())]) def __ne__(self, other): - return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())]) + return sql.or_(*[a!=b for a, b in + zip(self.prop.columns, + other.__composite_values__())]) + +class SynonymProperty(MapperProperty): + def __init__(self, name, map_column=None, descriptor=None): + self.name = name + self.map_column=map_column + self.descriptor = descriptor + + def setup(self, querycontext, **kwargs): + pass + + def create_row_processor(self, selectcontext, mapper, row): + return (None, None, None) + + def do_init(self): + class_ = self.parent.class_ + def comparator(): + return self.parent._get_property(self.key, resolve_synonyms=True).comparator + self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__)) + if self.descriptor is None: + class SynonymProp(object): + def __set__(s, obj, value): + setattr(obj, self.name, value) + def __delete__(s, obj): + delattr(obj, self.name) + def __get__(s, obj, owner): + if obj is None: + return s + return getattr(obj, self.name) + self.descriptor = SynonymProp() + sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator) + + def merge(self, session, source, dest, _recursive): + pass +SynonymProperty.logger = logging.class_logger(SynonymProperty) + + +class ComparableProperty(MapperProperty): + """Instruments a Python property for use in query expressions.""" + + def __init__(self, comparator_factory, descriptor=None): + self.descriptor = descriptor + self.comparator = comparator_factory(self) + + def do_init(self): + """Set up a proxy to the unmanaged descriptor.""" + + class_ = self.parent.class_ + # refactor me + sessionlib.register_attribute(class_, self.key, uselist=False, + proxy_property=self.descriptor, + useobject=False, + comparator=self.comparator) + + def setup(self, querycontext, **kwargs): + pass + + def create_row_processor(self, selectcontext, mapper, row): + return (None, None, None) + class 
PropertyLoader(StrategizedProperty): """Describes an object property that holds a single item or list of items that correspond to a related database table. """ - def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None): + def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, passive_updates=True, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None, _local_remote_pairs=None): self.uselist = uselist self.argument = argument self.entity_name = entity_name @@ -138,23 +223,46 @@ class PropertyLoader(StrategizedProperty): self.lazy = lazy self.foreign_keys = util.to_set(foreign_keys) self._legacy_foreignkey = util.to_set(foreignkey) + if foreignkey: + util.warn_deprecated('foreignkey option is deprecated; see docs for details') self.collection_class = collection_class self.passive_deletes = passive_deletes + self.passive_updates = passive_updates self.remote_side = util.to_set(remote_side) self.enable_typechecks = enable_typechecks - self._parent_join_cache = {} self.comparator = PropertyLoader.Comparator(self) self.join_depth = join_depth + self._arg_local_remote_pairs = _local_remote_pairs + + if strategy_class: + self.strategy_class = strategy_class + elif self.lazy == 'dynamic': + from sqlalchemy.orm import dynamic + self.strategy_class = dynamic.DynaLoader + elif self.lazy is False: + self.strategy_class = strategies.EagerLoader + elif self.lazy is None: + self.strategy_class = strategies.NoLoader + else: + self.strategy_class = strategies.LazyLoader + + self._reverse_property = None if cascade is not None: - self.cascade = mapperutil.CascadeOptions(cascade) + self.cascade = CascadeOptions(cascade) else: if private: - self.cascade = mapperutil.CascadeOptions("all, delete-orphan") + util.warn_deprecated('private option is deprecated; see docs for details') + self.cascade = CascadeOptions("all, delete-orphan") else: - self.cascade = mapperutil.CascadeOptions("save-update, merge") + self.cascade = CascadeOptions("save-update, merge") + + if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade): + raise exceptions.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade") self.association = association + if association: + util.warn_deprecated('association option is deprecated; see docs for details') self.order_by = order_by self.attributeext=attributeext if isinstance(backref, str): @@ -162,20 +270,31 @@ class PropertyLoader(StrategizedProperty): # just a string was sent if secondary is not None: # reverse primary/secondary in case of a many-to-many - self.backref = BackRef(backref, primaryjoin=secondaryjoin, secondaryjoin=primaryjoin) + self.backref = BackRef(backref, primaryjoin=secondaryjoin, secondaryjoin=primaryjoin, passive_updates=self.passive_updates) else: - self.backref = BackRef(backref, primaryjoin=primaryjoin, 
secondaryjoin=secondaryjoin) + self.backref = BackRef(backref, primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, passive_updates=self.passive_updates) else: self.backref = backref self.is_backref = is_backref class Comparator(PropComparator): + def __init__(self, prop, of_type=None): + self.prop = self.property = prop + if of_type: + self._of_type = _class_to_mapper(of_type) + + def of_type(self, cls): + return PropertyLoader.Comparator(self.prop, cls) + def __eq__(self, other): if other is None: - return ~sql.exists([1], self.prop.primaryjoin) + if self.prop.direction in [ONETOMANY, MANYTOMANY]: + return ~sql.exists([1], self.prop.primaryjoin) + else: + return self.prop._optimized_compare(None) elif self.prop.uselist: if not hasattr(other, '__iter__'): - raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.") + raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object. Use contains().") else: j = self.prop.primaryjoin if self.prop.secondaryjoin: @@ -186,200 +305,216 @@ class PropertyLoader(StrategizedProperty): sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))])) ) return sql.and_(*clauses) - else: + else: return self.prop._optimized_compare(other) - - def any(self, criterion=None, **kwargs): - if not self.prop.uselist: - raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") - j = self.prop.primaryjoin - if self.prop.secondaryjoin: - j = j & self.prop.secondaryjoin + + def _join_and_criterion(self, criterion=None, **kwargs): + if getattr(self, '_of_type', None): + target_mapper = self._of_type + to_selectable = target_mapper._with_polymorphic_selectable() #mapped_table + else: + to_selectable = None + + pj, sj, source, dest, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable) + for k in kwargs: crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) if criterion is None: criterion = crit else: criterion = criterion & crit - return sql.exists([1], j & criterion) - + + if sj: + j = pj & sj + else: + j = pj + + if criterion and target_adapter: + criterion = target_adapter.traverse(criterion) + + return j, criterion, dest + + def any(self, criterion=None, **kwargs): + if not self.prop.uselist: + raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") + j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs) + + return sql.exists([1], j & criterion, from_obj=from_obj) + def has(self, criterion=None, **kwargs): if self.prop.uselist: raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().") - j = self.prop.primaryjoin - if self.prop.secondaryjoin: - j = j & self.prop.secondaryjoin - for k in kwargs: - crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) - if criterion is None: - criterion = crit - else: - criterion = criterion & crit - return sql.exists([1], j & criterion) - + j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs) + + return sql.exists([1], j & criterion, from_obj=from_obj) + def contains(self, other): if not self.prop.uselist: raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. 
Use ==") clause = self.prop._optimized_compare(other) - j = self.prop.primaryjoin if self.prop.secondaryjoin: - j = j & self.prop.secondaryjoin + clause.negation_clause = self._negated_contains_or_equals(other) - clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) return clause + def _negated_contains_or_equals(self, other): + criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]) + j, criterion, from_obj = self._join_and_criterion(criterion) + return ~sql.exists([1], j & criterion, from_obj=from_obj) + def __ne__(self, other): + if other is None: + if self.prop.direction == MANYTOONE: + return sql.or_(*[x!=None for x in self.prop.foreign_keys]) + elif self.prop.uselist: + return self.any() + else: + return self.has() + if self.prop.uselist and not hasattr(other, '__iter__'): raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") - - j = self.prop.primaryjoin - if self.prop.secondaryjoin: - j = j & self.prop.secondaryjoin - return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) - + + return self._negated_contains_or_equals(other) + def compare(self, op, value, value_is_parent=False): - if op == operator.eq: + if op == operators.eq: if value is None: - return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin) + if self.uselist: + return ~sql.exists([1], self.primaryjoin) + else: + return self._optimized_compare(None, value_is_parent=value_is_parent) else: return self._optimized_compare(value, value_is_parent=value_is_parent) else: return op(self.comparator, value) - + def _optimized_compare(self, value, value_is_parent=False): - # optimized operation for ==, uses a lazy clause. - (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent) - bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) - - class Visitor(sql.ClauseVisitor): - def visit_bindparam(s, bindparam): - mapper = value_is_parent and self.parent or self.mapper - bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key]) - Visitor().traverse(criterion) - return criterion - - private = property(lambda s:s.cascade.delete_orphan) + return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent) - def create_strategy(self): - if self.lazy: - return strategies.LazyLoader(self) - elif self.lazy is False: - return strategies.EagerLoader(self) - elif self.lazy is None: - return strategies.NoLoader(self) + def private(self): + return self.cascade.delete_orphan + private = property(private) def __str__(self): return str(self.parent.class_.__name__) + "." 
+ self.key + " (" + str(self.mapper.class_.__name__) + ")" - def merge(self, session, source, dest, _recursive): - if not "merge" in self.cascade or self.mapper in _recursive: + def merge(self, session, source, dest, dont_load, _recursive): + if not dont_load and self._reverse_property and (source, self._reverse_property) in _recursive: + return + + if not "merge" in self.cascade: + dest._state.expire_attributes([self.key]) return - childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True) - if childlist is None: + + instances = attributes.get_as_list(source._state, self.key, passive=True) + if not instances: return + if self.uselist: - # sets a blank collection according to the correct list class - dest_list = sessionlib.attribute_manager.init_collection(dest, self.key) - for current in list(childlist): - obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive) + dest_list = [] + for current in instances: + _recursive[(current, self)] = True + obj = session.merge(current, entity_name=self.mapper.entity_name, dont_load=dont_load, _recursive=_recursive) if obj is not None: - #dest_list.append_without_event(obj) - dest_list.append_with_event(obj) + dest_list.append(obj) + if dont_load: + coll = attributes.init_collection(dest, self.key) + for c in dest_list: + coll.append_without_event(c) + else: + getattr(dest.__class__, self.key).impl._set_iterable(dest._state, dest_list) else: - current = list(childlist)[0] + current = instances[0] if current is not None: - obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive) + _recursive[(current, self)] = True + obj = session.merge(current, entity_name=self.mapper.entity_name, dont_load=dont_load, _recursive=_recursive) if obj is not None: - setattr(dest, self.key, obj) + if dont_load: + dest.__dict__[self.key] = obj + else: + setattr(dest, self.key, obj) - def cascade_iterator(self, type, object, recursive, halt_on=None): - if not type in self.cascade: - return - passive = type != 'delete' or self.passive_deletes - mapper = self.mapper.primary_mapper() - for c in sessionlib.attribute_manager.get_as_list(object, self.key, passive=passive): - if c is not None and c not in recursive and (halt_on is None or not halt_on(c)): - if not isinstance(c, self.mapper.class_): - raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) - recursive.add(c) - yield c - for c2 in mapper.cascade_iterator(type, c, recursive): - yield c2 - - def cascade_callable(self, type, object, callable_, recursive, halt_on=None): - if not type in self.cascade: + def cascade_iterator(self, type_, state, visited_instances, halt_on=None): + if not type_ in self.cascade: return - + passive = type_ != 'delete' or self.passive_deletes mapper = self.mapper.primary_mapper() - passive = type != 'delete' or self.passive_deletes - for c in sessionlib.attribute_manager.get_as_list(object, self.key, passive=passive): - if c is not None and c not in recursive and (halt_on is None or not halt_on(c)): - if not isinstance(c, self.mapper.class_): - raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) - recursive.add(c) - callable_(c, mapper.entity_name) - mapper.cascade_callable(type, c, callable_, recursive) + instances = attributes.get_as_list(state, self.key, passive=passive) + if instances: + for c in 
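The dont_load flag threaded through PropertyLoader.merge() above is driven from the session; a minimal sketch, assuming Session.merge() exposes the flag under the same keyword name and reusing the hypothetical User mapping from earlier::

    # a detached object carrying a primary key
    detached = User()
    detached.id = 1
    detached.name = 'ed'

    # plain merge(): loads the existing row if present and copies state onto it
    persistent = session.merge(detached)

    # dont_load=True trusts the given state as-is: no SELECT is emitted, and
    # related collections are installed without firing load events
    persistent = session.merge(detached, dont_load=True)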
instances: + if c is not None and c not in visited_instances and (halt_on is None or not halt_on(c)): + if not isinstance(c, self.mapper.class_): + raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) + visited_instances.add(c) + + # cascade using the mapper local to this object, so that its individual properties are located + instance_mapper = object_mapper(c, entity_name=mapper.entity_name) + yield (c, instance_mapper, c._state) def _get_target_class(self): """Return the target class of the relation, even if the property has not been initialized yet. - """ + """ if isinstance(self.argument, type): return self.argument else: return self.argument.class_ def do_init(self): - self._determine_targets() - self._determine_joins() - self._determine_fks() - self._determine_direction() - self._determine_remote_side() - self._create_polymorphic_joins() + self.__determine_targets() + self.__determine_joins() + self.__determine_fks() + self.__determine_direction() + self.__determine_remote_side() self._post_init() - def _determine_targets(self): + def __determine_targets(self): if isinstance(self.argument, type): - self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)._check_compile() + self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False) elif isinstance(self.argument, mapper.Mapper): - self.mapper = self.argument._check_compile() + self.mapper = self.argument + elif callable(self.argument): + # accept a callable to suit various deferred-configurational schemes + self.mapper = mapper.class_mapper(self.argument(), entity_name=self.entity_name, compile=False) else: raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument))) - # ensure the "select_mapper", if different from the regular target mapper, is compiled. - self.mapper.get_select_mapper()._check_compile() - - if self.association is not None: - if isinstance(self.association, type): - self.association = mapper.class_mapper(self.association, entity_name=self.entity_name, compile=False)._check_compile() + if not self.parent.concrete: + for inheriting in self.parent.iterate_to_root(): + if inheriting is not self.parent and inheriting._get_property(self.key, raiseerr=False): + util.warn( + ("Warning: relation '%s' on mapper '%s' supercedes " + "the same relation on inherited mapper '%s'; this " + "can cause dependency issues during flush") % + (self.key, self.parent, inheriting)) self.target = self.mapper.mapped_table - self.select_mapper = self.mapper.get_select_mapper() - self.select_table = self.mapper.select_table - self.loads_polymorphic = self.target is not self.select_table + self.table = self.mapper.mapped_table if self.cascade.delete_orphan: if self.parent.class_ is self.mapper.class_: - raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade rule on a self-referential relationship. You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self))) + raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade " + "rule on a self-referential relationship. " + "You probably want cascade='all', which includes delete cascading but not orphan detection." 
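__determine_targets() now also accepts a callable, so a relation() can be declared before its target class exists; a sketch, as an alternative spelling of the earlier hypothetical mapping::

    # the lambda is only invoked when the mapper compiles, so Address may be
    # defined later in the module
    mapper(User, users, properties={
        'addresses': relation(lambda: Address, backref='user')
    })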
%(str(self))) self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_)) - def _determine_joins(self): + def __determine_joins(self): if self.secondaryjoin is not None and self.secondary is None: raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument") # if join conditions were not specified, figure them out based on foreign keys - + def _search_for_join(mapper, table): """find a join between the given mapper's mapped table and the given table. - will try the mapper's local table first for more specificity, then if not + will try the mapper's local table first for more specificity, then if not found will try the more general mapped table, which in the case of inheritance is a join.""" try: return sql.join(mapper.local_table, table) except exceptions.ArgumentError, e: return sql.join(mapper.mapped_table, table) - + try: if self.secondary is not None: if self.secondaryjoin is None: @@ -390,109 +525,150 @@ class PropertyLoader(StrategizedProperty): if self.primaryjoin is None: self.primaryjoin = _search_for_join(self.parent, self.target).onclause except exceptions.ArgumentError, e: - raise exceptions.ArgumentError("""Error determining primary and/or secondary join for relationship '%s'. If the underlying error cannot be corrected, you should specify the 'primaryjoin' (and 'secondaryjoin', if there is an association table present) keyword arguments to the relation() function (or for backrefs, by specifying the backref using the backref() function with keyword arguments) to explicitly specify the join conditions. Nested error is \"%s\"""" % (str(self), str(e))) - - # if using polymorphic mapping, the join conditions must be agasint the base tables of the mappers, - # as the loader strategies expect to be working with those now (they will adapt the join conditions - # to the "polymorphic" selectable as needed). since this is an API change, put an explicit check/ - # error message in case its the "old" way. - if self.loads_polymorphic: - vis = sql_util.ColumnsInClause(self.mapper.select_table) - vis.traverse(self.primaryjoin) - if self.secondaryjoin: - vis.traverse(self.secondaryjoin) - if vis.result: - raise exceptions.ArgumentError("In relationship '%s', primary and secondary join conditions must not include columns from the polymorphic 'select_table' argument as of SA release 0.3.4. Construct join conditions using the base tables of the related mappers." % (str(self))) - - def _determine_fks(self): - if len(self._legacy_foreignkey) and not self._is_self_referential(): - self.foreign_keys = self._legacy_foreignkey + raise exceptions.ArgumentError("Could not determine join condition between parent/child tables on relation %s. " + "Specify a 'primaryjoin' expression. If this is a many-to-many relation, 'secondaryjoin' is needed as well." 
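When __determine_joins() cannot derive the join condition (the error just above), the conditions are spelled out by hand; a sketch against the hypothetical users/addresses tables::

    from sqlalchemy import and_

    mapper(User, users, properties={
        'boston_addresses': relation(Address,
            primaryjoin=and_(users.c.id == addresses.c.user_id,
                             addresses.c.email.like('%@boston.example.com')),
            foreign_keys=[addresses.c.user_id],
            viewonly=True)
    })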
% (self)) - def col_is_part_of_mappings(col): - if self.secondary is None: - return self.parent.mapped_table.corresponding_column(col, raiseerr=False) is not None or \ - self.target.corresponding_column(col, raiseerr=False) is not None - else: - return self.parent.mapped_table.corresponding_column(col, raiseerr=False) is not None or \ - self.target.corresponding_column(col, raiseerr=False) is not None or \ - self.secondary.corresponding_column(col, raiseerr=False) is not None - - if len(self.foreign_keys): - self._opposite_side = util.Set() - def visit_binary(binary): - if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): - return - if binary.left in self.foreign_keys: - self._opposite_side.add(binary.right) - if binary.right in self.foreign_keys: - self._opposite_side.add(binary.left) - mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin) - if self.secondaryjoin is not None: - mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin) + + def __col_is_part_of_mappings(self, column): + if self.secondary is None: + return self.parent.mapped_table.c.contains_column(column) or \ + self.target.c.contains_column(column) else: - self.foreign_keys = util.Set() - self._opposite_side = util.Set() - def visit_binary(binary): - if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): - return - - # this check is for when the user put the "view_only" flag on and has tables that have nothing - # to do with the relationship's parent/child mappings in the join conditions. we dont want cols - # or clauses related to those external tables dealt with. see orm.relationships.ViewOnlyTest - if not col_is_part_of_mappings(binary.left) or not col_is_part_of_mappings(binary.right): - return - - for f in binary.left.foreign_keys: - if f.references(binary.right.table): - self.foreign_keys.add(binary.left) - self._opposite_side.add(binary.right) - for f in binary.right.foreign_keys: - if f.references(binary.left.table): - self.foreign_keys.add(binary.right) - self._opposite_side.add(binary.left) - mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin) - - if len(self.foreign_keys) == 0: - raise exceptions.ArgumentError( - "Can't locate any foreign key columns in primary join " - "condition '%s' for relationship '%s'. Specify " - "'foreign_keys' argument to indicate which columns in " - "the join condition are foreign." 
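The experimental _local_remote_pairs path handled above requires the foreign_keys argument as well; a heavily hedged sketch, where the coalesce() join condition is purely illustrative::

    from sqlalchemy import func

    mapper(User, users, properties={
        'addresses': relation(Address,
            primaryjoin=users.c.id == func.coalesce(addresses.c.user_id, -1),
            foreign_keys=[addresses.c.user_id],
            _local_remote_pairs=[(users.c.id, addresses.c.user_id)])
    })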
%(str(self.primaryjoin), str(self))) - if self.secondaryjoin is not None: - mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin) + return self.parent.mapped_table.c.contains_column(column) or \ + self.target.c.contains_column(column) or \ + self.secondary.c.contains_column(column) is not None + + def __determine_fks(self): + if self._legacy_foreignkey and not self._refers_to_parent_table(): + self.foreign_keys = self._legacy_foreignkey - def _determine_direction(self): + arg_foreign_keys = self.foreign_keys + + if self._arg_local_remote_pairs: + if not arg_foreign_keys: + raise exceptions.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument") + self.foreign_keys = util.OrderedSet(arg_foreign_keys) + self._opposite_side = util.OrderedSet() + for l, r in self._arg_local_remote_pairs: + if r in self.foreign_keys: + self._opposite_side.add(l) + elif l in self.foreign_keys: + self._opposite_side.add(r) + self.synchronize_pairs = zip(self._opposite_side, self.foreign_keys) + else: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly) + eq_pairs = [(l, r) for l, r in eq_pairs if (self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)) or r in arg_foreign_keys] + + if not eq_pairs: + if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True): + raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. " + "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.primaryjoin, self) + ) + else: + if arg_foreign_keys: + raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. " + "Specify _local_remote_pairs=[(local, remote), (local, remote), ...] to explicitly establish the local/remote column pairs." % (self.primaryjoin, self)) + else: + raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. " + "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self)) + + self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs]) + self._opposite_side = util.OrderedSet([l for l, r in eq_pairs]) + self.synchronize_pairs = eq_pairs + + if self.secondaryjoin: + sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly) + sq_pairs = [(l, r) for l, r in sq_pairs if (self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)) or r in arg_foreign_keys] + + if not sq_pairs: + if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True): + raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for secondaryjoin condition '%s' on relation %s. " + "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.secondaryjoin, self) + ) + else: + raise exceptions.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. " + "Specify the foreign_keys argument to indicate which columns on the relation are foreign." 
% (self.secondaryjoin, self)) + + self.foreign_keys.update([r for l, r in sq_pairs]) + self._opposite_side.update([l for l, r in sq_pairs]) + self.secondary_synchronize_pairs = sq_pairs + else: + self.secondary_synchronize_pairs = None + + def __determine_remote_side(self): + if self._arg_local_remote_pairs: + if self.remote_side: + raise exceptions.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.") + if self.direction is MANYTOONE: + eq_pairs = [(r, l) for l, r in self._arg_local_remote_pairs] + else: + eq_pairs = self._arg_local_remote_pairs + elif self.remote_side: + if self.direction is MANYTOONE: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True) + else: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True) + else: + if self.viewonly: + eq_pairs = self.synchronize_pairs + else: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True) + if self.secondaryjoin: + sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True) + eq_pairs += sq_pairs + eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)] + + if self.direction is MANYTOONE: + self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] + self.local_remote_pairs = [(r, l) for l, r in eq_pairs] + else: + self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] + self.local_remote_pairs = eq_pairs + + if self.direction is ONETOMANY: + for l in self.local_side: + if not self.__col_is_part_of_mappings(l): + raise exceptions.ArgumentError("Local column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should compare against." % (l, self.parent)) + elif self.direction is MANYTOONE: + for r in self.remote_side: + if not self.__col_is_part_of_mappings(r): + raise exceptions.ArgumentError("Remote column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should bind." % (r, self.mapper)) + + def __determine_direction(self): """Determine our *direction*, i.e. do we represent one to many, many to many, etc. """ if self.secondaryjoin is not None: - self.direction = sync.MANYTOMANY - elif self._is_self_referential(): + self.direction = MANYTOMANY + elif self._refers_to_parent_table(): # for a self referential mapper, if the "foreignkey" is a single or composite primary key, # then we are "many to one", since the remote site of the relationship identifies a singular entity. # otherwise we are "one to many". 
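__determine_remote_side() is what makes the classic self-referential adjacency list work; a minimal sketch with a hypothetical nodes table::

    nodes = Table('nodes', metadata,
        Column('id', Integer, primary_key=True),
        Column('parent_id', Integer, ForeignKey('nodes.id')),
        Column('data', String(50)))

    class Node(object): pass

    # naming the primary key as the remote side makes 'parent' many-to-one;
    # the complementary 'children' backref becomes one-to-many
    mapper(Node, nodes, properties={
        'parent': relation(Node, remote_side=[nodes.c.id], backref='children')
    })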
- if len(self._legacy_foreignkey): + if self._legacy_foreignkey: for f in self._legacy_foreignkey: if not f.primary_key: - self.direction = sync.ONETOMANY + self.direction = ONETOMANY else: - self.direction = sync.MANYTOONE - - elif len(self.remote_side): - for f in self.foreign_keys: - if f in self.remote_side: - self.direction = sync.ONETOMANY - return + self.direction = MANYTOONE + elif self._arg_local_remote_pairs: + remote = util.Set([r for l, r in self._arg_local_remote_pairs]) + if self.foreign_keys.intersection(remote): + self.direction = ONETOMANY + else: + self.direction = MANYTOONE + elif self.remote_side: + if self.foreign_keys.intersection(self.remote_side): + self.direction = ONETOMANY else: - self.direction = sync.MANYTOONE + self.direction = MANYTOONE else: - self.direction = sync.ONETOMANY + self.direction = ONETOMANY else: for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]: - onetomany = len([c for c in self.foreign_keys if mappedtable.c.contains_column(c)]) - manytoone = len([c for c in self.foreign_keys if parenttable.c.contains_column(c)]) + onetomany = [c for c in self.foreign_keys if mappedtable.c.contains_column(c)] + manytoone = [c for c in self.foreign_keys if parenttable.c.contains_column(c)] if not onetomany and not manytoone: raise exceptions.ArgumentError( @@ -502,10 +678,10 @@ class PropertyLoader(StrategizedProperty): elif onetomany and manytoone: continue elif onetomany: - self.direction = sync.ONETOMANY + self.direction = ONETOMANY break elif manytoone: - self.direction = sync.MANYTOONE + self.direction = MANYTOONE break else: raise exceptions.ArgumentError( @@ -514,66 +690,16 @@ class PropertyLoader(StrategizedProperty): "the child's mapped tables. Specify 'foreign_keys' " "argument." % (str(self))) - def _determine_remote_side(self): - if not len(self.remote_side): - if self.direction is sync.MANYTOONE: - self.remote_side = util.Set(self._opposite_side) - elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY: - self.remote_side = util.Set(self.foreign_keys) - - self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side) - - def _create_polymorphic_joins(self): - # get ready to create "polymorphic" primary/secondary join clauses. - # these clauses represent the same join between parent/child tables that the primary - # and secondary join clauses represent, except they reference ColumnElements that are specifically - # in the "polymorphic" selectables. these are used to construct joins for both Query as well as - # eager loading, and also are used to calculate "lazy loading" clauses. - - # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out, - # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge - # several "equivalent" columns (such as parent/child fk cols) into just one column. 
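The "can't determine relation direction" branches above are typically hit when a child table carries two foreign keys to the same parent; distinct primaryjoin conditions (optionally with foreign_keys) settle which columns count as foreign for each relation, and therefore its direction. A sketch with a hypothetical invoices table::

    invoices = Table('invoices', metadata,
        Column('id', Integer, primary_key=True),
        Column('billing_address_id', Integer, ForeignKey('addresses.id')),
        Column('shipping_address_id', Integer, ForeignKey('addresses.id')))

    class Invoice(object): pass

    mapper(Invoice, invoices, properties={
        'billing_address': relation(Address,
            primaryjoin=invoices.c.billing_address_id == addresses.c.id,
            foreign_keys=[invoices.c.billing_address_id]),
        'shipping_address': relation(Address,
            primaryjoin=invoices.c.shipping_address_id == addresses.c.id,
            foreign_keys=[invoices.c.shipping_address_id]),
    })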
- - target_equivalents = self.mapper._get_equivalent_columns() - - # if the target mapper loads polymorphically, adapt the clauses to the target's selectable - if self.loads_polymorphic: - if self.secondaryjoin: - self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True) - self.polymorphic_primaryjoin = self.primaryjoin - else: - if self.direction is sync.ONETOMANY: - self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) - elif self.direction is sync.MANYTOONE: - self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) - self.polymorphic_secondaryjoin = None - # load "polymorphic" versions of the columns present in "remote_side" - this is - # important for lazy-clause generation which goes off the polymorphic target selectable - for c in list(self.remote_side): - if self.secondary and self.secondary.columns.contains_column(c): - continue - for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []): - corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False) - if corr: - self.remote_side.add(corr) - break - else: - raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table)) - else: - self.polymorphic_primaryjoin = self.primaryjoin - self.polymorphic_secondaryjoin = self.secondaryjoin - def _post_init(self): if logging.is_info_enabled(self.logger): - self.logger.info(str(self) + " setup primary join " + str(self.primaryjoin)) - self.logger.info(str(self) + " setup polymorphic primary join " + str(self.polymorphic_primaryjoin)) - self.logger.info(str(self) + " setup secondary join " + str(self.secondaryjoin)) - self.logger.info(str(self) + " setup polymorphic secondary join " + str(self.polymorphic_secondaryjoin)) - self.logger.info(str(self) + " foreign keys " + str([str(c) for c in self.foreign_keys])) - self.logger.info(str(self) + " remote columns " + str([str(c) for c in self.remote_side])) - self.logger.info(str(self) + " relation direction " + (self.direction is sync.ONETOMANY and "one-to-many" or (self.direction is sync.MANYTOONE and "many-to-one" or "many-to-many"))) - - if self.uselist is None and self.direction is sync.MANYTOONE: + self.logger.info(str(self) + " setup primary join %s" % self.primaryjoin) + self.logger.info(str(self) + " setup secondary join %s" % self.secondaryjoin) + self.logger.info(str(self) + " synchronize pairs [%s]" % ",".join(["(%s => %s)" % (l, r) for l, r in self.synchronize_pairs])) + self.logger.info(str(self) + " secondary synchronize pairs [%s]" % ",".join(["(%s => %s)" % (l, r) for l, r in self.secondary_synchronize_pairs or []])) + self.logger.info(str(self) + " local/remote pairs [%s]" % ",".join(["(%s / %s)" % (l, r) for l, r in self.local_remote_pairs])) + self.logger.info(str(self) + " relation direction %s" % self.direction) + + if self.uselist is None and self.direction is MANYTOONE: self.uselist = False if self.uselist is None: @@ -591,41 +717,73 @@ class PropertyLoader(StrategizedProperty): if self.backref is not None: self.backref.compile(self) - elif not sessionlib.attribute_manager.is_class_managed(self.parent.class_, self.key): + elif not mapper.class_mapper(self.parent.class_, 
compile=False)._get_property(self.key, raiseerr=False): raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__)) super(PropertyLoader, self).do_init() + def _refers_to_parent_table(self): + return self.parent.mapped_table is self.target or self.parent.mapped_table is self.target + def _is_self_referential(self): - return self.parent.mapped_table is self.target or self.parent.select_table is self.target - - def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True): - try: - return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] - except KeyError: - parent_equivalents = parent._get_equivalent_columns() - secondaryjoin = self.polymorphic_secondaryjoin - if polymorphic_parent: - # adapt the "parent" side of our join condition to the "polymorphic" select of the parent - if self.direction is sync.ONETOMANY: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) - elif self.direction is sync.MANYTOONE: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) - elif self.secondaryjoin: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) - - if secondaryjoin is not None: - if secondary and not primary: - j = secondaryjoin - elif primary and secondary: - j = primaryjoin & secondaryjoin - elif primary and not secondary: - j = primaryjoin + return self.mapper.common_parent(self.parent) + + def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None): + if source_selectable is None: + if source_polymorphic and self.parent.with_polymorphic: + source_selectable = self.parent._with_polymorphic_selectable() else: - j = primaryjoin - self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j - return j - + source_selectable = None + if dest_selectable is None: + if dest_polymorphic and self.mapper.with_polymorphic: + dest_selectable = self.mapper._with_polymorphic_selectable() + else: + dest_selectable = self.mapper.mapped_table + if self._is_self_referential(): + if dest_selectable: + dest_selectable = dest_selectable.alias() + else: + dest_selectable = self.mapper.mapped_table.alias() + + primaryjoin = self.primaryjoin + if source_selectable: + if self.direction in (ONETOMANY, MANYTOMANY): + primaryjoin = ClauseAdapter(source_selectable, exclude=self.foreign_keys, equivalents=self.parent._equivalent_columns).traverse(primaryjoin) + else: + primaryjoin = ClauseAdapter(source_selectable, include=self.foreign_keys, equivalents=self.parent._equivalent_columns).traverse(primaryjoin) + + secondaryjoin = self.secondaryjoin + target_adapter = None + if dest_selectable: + if self.direction == ONETOMANY: + target_adapter = ClauseAdapter(dest_selectable, include=self.foreign_keys, equivalents=self.mapper._equivalent_columns) + elif self.direction == MANYTOMANY: + target_adapter = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns) + else: + target_adapter = ClauseAdapter(dest_selectable, 
exclude=self.foreign_keys, equivalents=self.mapper._equivalent_columns) + if secondaryjoin: + secondaryjoin = target_adapter.traverse(secondaryjoin) + else: + primaryjoin = target_adapter.traverse(primaryjoin) + target_adapter.include = target_adapter.exclude = None + + return primaryjoin, secondaryjoin, source_selectable or self.parent.local_table, dest_selectable or self.mapper.local_table, target_adapter + + def _get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True): + """deprecated. use primary_join_against(), secondary_join_against(), full_join_against()""" + + pj, sj, source, dest, adapter = self._create_joins(source_polymorphic=polymorphic_parent) + + if primary and secondary: + return pj & sj + elif primary: + return pj + elif secondary: + return sj + else: + raise AssertionError("illegal condition") + + def register_dependencies(self, uowcommit): if not self.viewonly: self._dependency_processor.register_dependencies(uowcommit) @@ -633,54 +791,48 @@ class PropertyLoader(StrategizedProperty): PropertyLoader.logger = logging.class_logger(PropertyLoader) class BackRef(object): - """Stores the name of a backreference property as well as options - to be used on the resulting PropertyLoader. - """ + """Attached to a PropertyLoader to indicate a complementary reverse relationship. + + Can optionally create the complementing PropertyLoader if one does not exist already.""" - def __init__(self, key, **kwargs): + def __init__(self, key, _prop=None, **kwargs): self.key = key self.kwargs = kwargs + self.prop = _prop def compile(self, prop): - """Called by the owning PropertyLoader to set up a - backreference on the PropertyLoader's mapper. - """ + if self.prop: + return + + self.prop = prop - # try to set a LazyLoader on our mapper referencing the parent mapper mapper = prop.mapper.primary_mapper() - if not mapper.get_property(self.key, raiseerr=False) is not None: + if mapper._get_property(self.key, raiseerr=False) is None: pj = self.kwargs.pop('primaryjoin', None) sj = self.kwargs.pop('secondaryjoin', None) - # the backref property is set on the primary mapper + parent = prop.parent.primary_mapper() self.kwargs.setdefault('viewonly', prop.viewonly) self.kwargs.setdefault('post_update', prop.post_update) + relation = PropertyLoader(parent, prop.secondary, pj, sj, - backref=prop.key, is_backref=True, + backref=BackRef(prop.key, _prop=prop), + is_backref=True, **self.kwargs) + mapper._compile_property(self.key, relation); - elif not isinstance(mapper.get_property(self.key), PropertyLoader): - raise exceptions.ArgumentError( - "Can't create backref '%s' on mapper '%s'; an incompatible " - "property of that name already exists" % (self.key, str(mapper))) + + prop._reverse_property = mapper._get_property(self.key) + mapper._get_property(self.key)._reverse_property = prop + else: - # else set one of us as the "backreference" - parent = prop.parent.primary_mapper() - if parent.class_ is not mapper.get_property(self.key)._get_target_class(): - raise exceptions.ArgumentError( - "Backrefs do not match: backref '%s' expects to connect to %s, " - "but found a backref already connected to %s" % - (self.key, str(parent.class_), str(mapper.get_property(self.key).mapper.class_))) - if not mapper.get_property(self.key).is_backref: - prop.is_backref=True - if not prop.viewonly: - prop._dependency_processor.is_backref=True - # reverse_property used by dependencies.ManyToManyDP to check - # association table operations - prop.reverse_property = mapper.get_property(self.key) - 
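BackRef.compile() builds the complementary PropertyLoader from whatever keyword arguments backref() was given; a sketch of the configurational side, again using the hypothetical users/addresses tables::

    from sqlalchemy.orm import backref

    mapper(Address, addresses, properties={
        'user': relation(User,
            backref=backref('addresses', order_by=addresses.c.id))
    })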
mapper.get_property(self.key).reverse_property = prop + raise exceptions.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper)) def get_extension(self): """Return an attribute extension to use with this backreference.""" return attributes.GenericBackrefExtension(self.key) + +mapper.ColumnProperty = ColumnProperty +mapper.SynonymProperty = SynonymProperty +mapper.ComparableProperty = ComparableProperty diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 284653b5c5..8996a758e6 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1,32 +1,48 @@ # orm/query.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, sql_util, logging +"""The Query class and support. + +Defines the [sqlalchemy.orm.query#Query] class, the central construct used by +the ORM to construct database queries. + +The ``Query`` class should not be confused with the +[sqlalchemy.sql.expression#Select] class, which defines database SELECT +operations at the SQL (non-ORM) level. ``Query`` differs from ``Select`` in +that it returns ORM-mapped objects and interacts with an ORM session, whereas +the ``Select`` construct interacts directly with the database to return +iterable result sets. +""" + +from itertools import chain +from sqlalchemy import sql, util, exceptions, logging +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import expression, visitors, operators from sqlalchemy.orm import mapper, object_mapper + +from sqlalchemy.orm.util import _state_mapper, _class_to_mapper, _is_mapped_class, _is_aliased_class from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.interfaces import OperationContext, LoaderStack -import operator +from sqlalchemy.orm import interfaces +from sqlalchemy.orm import attributes +from sqlalchemy.orm.util import AliasedClass + +aliased = AliasedClass + +__all__ = ['Query', 'QueryContext', 'aliased'] -__all__ = ['Query', 'QueryContext', 'SelectionContext'] class Query(object): """Encapsulates the object-fetching operations provided by Mappers.""" - + def __init__(self, class_or_mapper, session=None, entity_name=None): - if isinstance(class_or_mapper, type): - self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name) - else: - self.mapper = class_or_mapper.compile() - self.select_mapper = self.mapper.get_select_mapper().compile() - self._session = session - + self._with_options = [] self._lockmode = None - self._extension = self.mapper.extension.copy() + self._entities = [] self._order_by = False self._group_by = False @@ -35,113 +51,297 @@ class Query(object): self._limit = None self._statement = None self._params = {} + self._yield_per = None self._criterion = None + self.__joinable_tables = None + self._having = None self._column_aggregate = None - self._joinpoint = self.mapper - self._aliases = None - self._alias_ids = {} - self._from_obj = [self.table] self._populate_existing = False self._version_check = False + self._autoflush = True + + self._attributes = {} + self._current_path = () + self._only_load_props = None + self._refresh_instance = None + + self.__init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name)) + + def 
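The new module docstring draws the Query/Select distinction; roughly, assuming the hypothetical users table, engine and session from the earlier sketches::

    from sqlalchemy import select

    # SQL layer: a Select executes against the engine and yields raw rows
    rows = engine.execute(select([users], users.c.name == 'ed')).fetchall()

    # ORM layer: a Query goes through the session and yields User instances
    eds = session.query(User).filter(User.name == 'ed').all()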
__init_mapper(self, mapper): + """populate all instance variables derived from this Query's mapper.""" + self.mapper = mapper + self.table = self._from_obj = self.mapper.mapped_table + self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]])) + self._extension = self.mapper.extension + self._aliases_head = self._aliases_tail = None + self._alias_ids = {} + self._joinpoint = self.mapper + self._entities.append(_PrimaryMapperEntity(self.mapper)) + if self.mapper.with_polymorphic: + self.__set_with_polymorphic(*self.mapper.with_polymorphic) + else: + self._with_polymorphic = [] + + def __generate_alias_ids(self): + self._alias_ids = dict([ + (k, list(v)) for k, v in self._alias_ids.iteritems() + ]) + + def __no_criterion(self, meth): + return self.__conditional_clone(meth, [self.__no_criterion_condition]) + + def __no_statement(self, meth): + return self.__conditional_clone(meth, [self.__no_statement_condition]) + + def __reset_all(self, mapper, meth): + q = self.__conditional_clone(meth, [self.__no_criterion_condition]) + q.__init_mapper(mapper, mapper) + return q + + def __set_select_from(self, from_obj): + if isinstance(from_obj, expression._SelectBaseMixin): + # alias SELECTs and unions + from_obj = from_obj.alias() + + self._from_obj = from_obj + self._alias_ids = {} + + if self.table not in self._get_joinable_tables(): + self._aliases_head = self._aliases_tail = mapperutil.AliasedClauses(self._from_obj, equivalents=self.mapper._equivalent_columns) + self._alias_ids.setdefault(self.table, []).append(self._aliases_head) + else: + self._aliases_head = self._aliases_tail = None + + def __set_with_polymorphic(self, cls_or_mappers, selectable=None): + mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable) + self._with_polymorphic = mappers + self.__set_select_from(from_obj) + + def __no_criterion_condition(self, q, meth): + if q._criterion or q._statement: + util.warn( + ("Query.%s() being called on a Query with existing criterion; " + "criterion is being ignored.") % meth) + + q._joinpoint = self.mapper + q._statement = q._criterion = None + q._order_by = q._group_by = q._distinct = False + q._aliases_tail = q._aliases_head + q.table = q._from_obj = q.mapper.mapped_table + if q.mapper.with_polymorphic: + q.__set_with_polymorphic(*q.mapper.with_polymorphic) + + def __no_entities(self, meth): + q = self.__no_statement(meth) + if len(q._entities) > 1 and not isinstance(q._entities[0], _PrimaryMapperEntity): + raise exceptions.InvalidRequestError( + ("Query.%s() being called on a Query with existing " + "additional entities or columns - can't replace columns") % meth) + q._entities = [] + return q + + def __no_statement_condition(self, q, meth): + if q._statement: + raise exceptions.InvalidRequestError( + ("Query.%s() being called on a Query with an existing full " + "statement - can't apply criterion.") % meth) + + def __conditional_clone(self, methname=None, conditions=None): + q = self._clone() + if conditions: + for condition in conditions: + condition(q, methname) + return q + + def __get_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None): + if populate_existing: + self._populate_existing = populate_existing + if version_check: + self._version_check = version_check + if refresh_instance: + self._refresh_instance = refresh_instance + if only_load_props: + self._only_load_props = util.Set(only_load_props) + return self + def _clone(self): q = 
Query.__new__(Query) q.__dict__ = self.__dict__.copy() return q - - def _get_session(self): + + def session(self): if self._session is None: return self.mapper.get_session() else: return self._session + session = property(session) + + def statement(self): + """return the full SELECT statement represented by this Query.""" + return self._compile_context().statement + statement = property(statement) + + def whereclause(self): + """return the WHERE criterion for this Query.""" + return self._criterion + whereclause = property(whereclause) - table = property(lambda s:s.select_mapper.mapped_table) - primary_key_columns = property(lambda s:s.select_mapper.primary_key) - session = property(_get_session) + def _with_current_path(self, path): + """indicate that this query applies to objects loaded within a certain path. + + Used by deferred loaders (see strategies.py) which transfer query + options from an originating query to a newly generated query intended + for the deferred load. + + """ + q = self._clone() + q._current_path = path + return q + + def with_polymorphic(self, cls_or_mappers, selectable=None): + """Load columns for descendant mappers of this Query's mapper. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + ``cls_or_mappers`` is a single class or mapper, or list of class/mappers, + which inherit from this Query's mapper. Alternatively, it + may also be the string ``'*'``, in which case all descending + mappers will be added to the FROM clause. + + ``selectable`` is a table or select() statement that will + be used in place of the generated FROM clause. This argument + is required if any of the desired mappers use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` + argument must represent the full set of tables and columns mapped + by every desired mapper. Otherwise, the unaccounted mapped columns + will result in their table being appended directly to the FROM + clause which will usually lead to incorrect results. + + """ + q = self.__no_criterion('with_polymorphic') + + q.__set_with_polymorphic(cls_or_mappers, selectable=selectable) + + return q + + + def yield_per(self, count): + """Yield only ``count`` rows at a time. + + WARNING: use this method with caution; if the same instance is present + in more than one batch of rows, end-user changes to attributes will be + overwritten. + + In particular, it's usually impossible to use this setting with + eagerly loaded collections (i.e. any lazy=False) since those + collections will be cleared for a new load when encountered in a + subsequent result batch. + """ + + q = self._clone() + q._yield_per = count + return q def get(self, ident, **kwargs): - """Return an instance of the object based on the given - identifier, or None if not found. + """Return an instance of the object based on the given identifier, or None if not found. - The `ident` argument is a scalar or tuple of primary key - column values in the order of the table def's primary key - columns. + The `ident` argument is a scalar or tuple of primary key column values + in the order of the table def's primary key columns. 
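Usage sketches for the two new methods above; Employee and Engineer stand in for a hypothetical joined-table inheritance mapping and are not defined in this patch::

    # pull in all subclass tables so filter() can use their columns directly
    q = session.query(Employee).with_polymorphic('*')
    q = q.filter(Engineer.primary_language == 'python')

    # fetch rows in batches of 100; avoid combining with eager-loaded
    # collections, per the warning in the docstring
    for user in session.query(User).yield_per(100):
        print user.name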
""" ret = self._extension.get(self, ident, **kwargs) - if ret is not mapper.EXT_PASS: + if ret is not mapper.EXT_CONTINUE: return ret # convert composite types to individual args - # TODO: account for the order of columns in the + # TODO: account for the order of columns in the # ColumnProperty it corresponds to - if hasattr(ident, '__colset__'): - ident = ident.__colset__() + if hasattr(ident, '__composite_values__'): + ident = ident.__composite_values__() key = self.mapper.identity_key_from_primary_key(ident) return self._get(key, ident, **kwargs) def load(self, ident, raiseerr=True, **kwargs): - """Return an instance of the object based on the given - identifier. - - If not found, raises an exception. The method will **remove - all pending changes** to the object already existing in the - Session. The `ident` argument is a scalar or tuple of primary - key column values in the order of the table def's primary key - columns. + """Return an instance of the object based on the given identifier. + + If not found, raises an exception. The method will **remove all + pending changes** to the object already existing in the Session. The + `ident` argument is a scalar or tuple of primary key column values in + the order of the table def's primary key columns. """ ret = self._extension.load(self, ident, **kwargs) - if ret is not mapper.EXT_PASS: + if ret is not mapper.EXT_CONTINUE: return ret key = self.mapper.identity_key_from_primary_key(ident) - instance = self._get(key, ident, reload=True, **kwargs) + instance = self.populate_existing()._get(key, ident, **kwargs) if instance is None and raiseerr: raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) return instance - + def query_from_parent(cls, instance, property, **kwargs): - """return a newly constructed Query object, with criterion corresponding to - a relationship to the given parent instance. + """Return a new Query with criterion corresponding to a parent instance. - instance - a persistent or detached instance which is related to class represented - by this query. + Return a newly constructed Query object, with criterion corresponding + to a relationship to the given parent instance. - property - string name of the property which relates this query's class to the - instance. - - \**kwargs - all extra keyword arguments are propigated to the constructor of - Query. - + instance + a persistent or detached instance which is related to class + represented by this query. + + property + string name of the property which relates this query's class to the + instance. + + \**kwargs + all extra keyword arguments are propagated to the constructor of + Query. """ - + mapper = object_mapper(instance) prop = mapper.get_property(property, resolve_synonyms=True) target = prop.mapper - criterion = prop.compare(operator.eq, instance, value_is_parent=True) + criterion = prop.compare(operators.eq, instance, value_is_parent=True) return Query(target, **kwargs).filter(criterion) query_from_parent = classmethod(query_from_parent) - - def populate_existing(self): - """return a Query that will refresh all instances loaded. + + def autoflush(self, setting): + """Return a Query with a specific 'autoflush' setting. + + Note that a Session with autoflush=False will + not autoflush, even if this flag is set to True at the + Query level. Therefore this flag is usually used only + to disable autoflush for a specific Query. 
- this includes all entities accessed from the database, including + """ + q = self._clone() + q._autoflush = setting + return q + + def populate_existing(self): + """Return a Query that will refresh all instances loaded. + + This includes all entities accessed from the database, including secondary entities, eagerly-loaded collection items. + + All changes present on entities which are already present in the + session will be reset and the entities will all be marked "clean". + + An alternative to populate_existing() is to expire the Session + fully using session.expire_all(). - All changes present on entities which are already present in the session will - be reset and the entities will all be marked "clean". - - This is essentially the en-masse version of load(). """ - q = self._clone() q._populate_existing = True return q - + def with_parent(self, instance, property=None): """add a join criterion corresponding to a relationship to the given parent instance. @@ -150,13 +350,13 @@ class Query(object): by this query. property - string name of the property which relates this query's class to the + string name of the property which relates this query's class to the instance. if None, the method will attempt to find a suitable property. currently, this method only works with immediate parent relationships, but in the - future may be enhanced to work across a chain of parent mappers. - """ + future may be enhanced to work across a chain of parent mappers. + """ from sqlalchemy.orm import properties mapper = object_mapper(instance) if property is None: @@ -167,113 +367,169 @@ class Query(object): raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__)) else: prop = mapper.get_property(property, resolve_synonyms=True) - return self.filter(prop.compare(operator.eq, instance, value_is_parent=True)) + return self.filter(prop.compare(operators.eq, instance, value_is_parent=True)) def add_entity(self, entity, alias=None, id=None): """add a mapped entity to the list of result columns to be returned. - + This will have the effect of all result-returning methods returning a tuple - of results, the first element being an instance of the primary class for this + of results, the first element being an instance of the primary class for this Query, and subsequent elements matching columns or entities which were specified via add_column or add_entity. - - When adding entities to the result, its generally desireable to add + + When adding entities to the result, its generally desirable to add limiting criterion to the query which can associate the primary entity of this Query along with the additional entities. The Query selects from all tables with no joining criterion by default. - + entity a class or mapper which will be added to the results. - + alias a sqlalchemy.sql.Alias object which will be used to select rows. this will match the usage of the given Alias in filter(), order_by(), etc. expressions - + id - a string ID matching that given to query.join() or query.outerjoin(); rows will be + a string ID matching that given to query.join() or query.outerjoin(); rows will be selected from the aliased join created via those methods. 
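Sketches for populate_existing() and the new per-query autoflush() override, reusing the hypothetical session from earlier::

    # re-load state for every instance encountered, even ones already
    # present (and possibly dirty) in the session
    fresh = session.query(User).populate_existing().all()

    # suppress autoflush for just this query; a Session created with
    # autoflush=False is unaffected either way
    pending_safe = session.query(User).autoflush(False).filter_by(name='ed').all()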
+ """ q = self._clone() + if not alias and _is_aliased_class(entity): + alias = entity.alias + if isinstance(entity, type): entity = mapper.class_mapper(entity) + if alias is not None: - alias = mapperutil.AliasedClauses(entity.mapped_table, alias=alias) + alias = mapperutil.AliasedClauses(alias) - q._entities = q._entities + [(entity, alias, id)] + q._entities = q._entities + [_MapperEntity(mapper=entity, alias=alias, id=id)] return q + + def _from_self(self): + """return a Query that selects from this Query's SELECT statement. - def add_column(self, column, id=None): - """add a SQL ColumnElement to the list of result columns to be returned. + The API for this method hasn't been decided yet and is subject to change. + + """ + q = self._clone() + q._eager_loaders = util.Set() + fromclause = q.compile().correlate(None) + return Query(self.mapper, self.session).select_from(fromclause) - This will have the effect of all result-returning methods returning a tuple - of results, the first element being an instance of the primary class for this - Query, and subsequent elements matching columns or entities which were - specified via add_column or add_entity. + def values(self, *columns): + """Return an iterator yielding result tuples corresponding to the given list of columns""" + + q = self.__no_entities('_values') + q._only_load_props = q._eager_loaders = util.Set() + q._no_filters = True + for column in columns: + q._entities.append(self._add_column(column, None, False)) + if not q._yield_per: + q = q.yield_per(10) + return iter(q) + _values = values + + def add_column(self, column, id=None): + """Add a SQL ColumnElement to the list of result columns to be returned. + + This will have the effect of all result-returning methods returning a + tuple of results, the first element being an instance of the primary + class for this Query, and subsequent elements matching columns or + entities which were specified via add_column or add_entity. - When adding columns to the result, its generally desireable to add + When adding columns to the result, its generally desirable to add limiting criterion to the query which can associate the primary entity - of this Query along with the additional columns, if the column is based on a - table or selectable that is not the primary mapped selectable. The Query selects - from all tables with no joining criterion by default. - - column - a string column name or sql.ColumnElement to be added to the results. - + of this Query along with the additional columns, if the column is + based on a table or selectable that is not the primary mapped + selectable. The Query selects from all tables with no joining + criterion by default. + + column + a string column name or sql.ColumnElement to be added to the results. + """ - q = self._clone() + q._entities = q._entities + [self._add_column(column, id, True)] + return q + + def _add_column(self, column, id, looks_for_aliases): + if isinstance(column, interfaces.PropComparator): + column = column.clause_element() - # alias non-labeled column elements. - if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'): - column = column.label(None) + elif not isinstance(column, (sql.ColumnElement, basestring)): + raise exceptions.InvalidRequestError("Invalid column expression '%r'" % column) - q._entities = q._entities + [(column, None, id)] - return q + return _ColumnEntity(column, id) def options(self, *args): """Return a new Query object, applying the given list of MapperOptions. 
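Multi-entity and extra-column results, sketched against the hypothetical users/addresses mapping::

    from sqlalchemy import func

    # each row is now a (User, Address) tuple
    q = session.query(User).add_entity(Address).filter(User.id == Address.user_id)
    for user, address in q.all():
        print user.name, address.email

    # add_column() appends a bare SQL expression to each row; strict backends
    # will want every selected column listed in the GROUP BY
    q = (session.query(User)
             .add_column(func.count(addresses.c.id))
             .outerjoin('addresses')
             .group_by(users.c.id))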
+ """ - + return self._options(False, *args) + + def _conditional_options(self, *args): + return self._options(True, *args) + + def _options(self, conditional, *args): q = self._clone() + # most MapperOptions write to the '_attributes' dictionary, + # so copy that as well + q._attributes = q._attributes.copy() opts = [o for o in util.flatten_iterator(args)] q._with_options = q._with_options + opts - for opt in opts: - opt.process_query(q) + if conditional: + for opt in opts: + opt.process_query_conditionally(q) + else: + for opt in opts: + opt.process_query(q) return q def with_lockmode(self, mode): """Return a new Query object with the specified locking mode.""" + q = self._clone() q._lockmode = mode return q - def params(self, **kwargs): - """add values for bind parameters which may have been specified in filter().""" - + def params(self, *args, **kwargs): + """add values for bind parameters which may have been specified in filter(). + + parameters may be specified using \**kwargs, or optionally a single dictionary + as the first positional argument. The reason for both is that \**kwargs is + convenient, however some parameter dictionaries contain unicode keys in which case + \**kwargs cannot be used. + + """ q = self._clone() + if len(args) == 1: + kwargs.update(args[0]) + elif len(args) > 0: + raise exceptions.ArgumentError("params() takes zero or one positional argument, which is a dictionary.") q._params = q._params.copy() q._params.update(kwargs) return q - + def filter(self, criterion): """apply the given filtering criterion to the query and return the newly resulting ``Query`` - + the criterion is any sql.ClauseElement applicable to the WHERE clause of a select. + """ - if isinstance(criterion, basestring): criterion = sql.text(criterion) - + if criterion is not None and not isinstance(criterion, sql.ClauseElement): raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") - - - if self._aliases is not None: - criterion = self._aliases.adapt_clause(criterion) - - q = self._clone() + + if self._aliases_tail: + criterion = self._aliases_tail.adapt_clause(criterion) + + q = self.__no_statement("filter") if q._criterion is not None: q._criterion = q._criterion & criterion else: @@ -283,130 +539,11 @@ class Query(object): def filter_by(self, **kwargs): """apply the given filtering criterion to the query and return the newly resulting ``Query``.""" - #import properties - - alias = None - join = None - clause = None - joinpoint = self._joinpoint - - for key, value in kwargs.iteritems(): - prop = joinpoint.get_property(key, resolve_synonyms=True) - c = prop.compare(operator.eq, value) - - if alias is not None: - sql_util.ClauseAdapter(alias).traverse(c) - if clause is None: - clause = c - else: - clause &= c - - if join is not None: - return self.select_from(join).filter(clause) - else: - return self.filter(clause) - - def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True): - if start is None: - start = self._joinpoint - - clause = self._from_obj[-1] - - currenttables = [clause] - class FindJoinedTables(sql.NoColumnVisitor): - def visit_join(self, join): - currenttables.append(join.left) - currenttables.append(join.right) - FindJoinedTables().traverse(clause) - - mapper = start - alias = self._aliases - for key in util.to_list(keys): - prop = mapper.get_property(key, resolve_synonyms=True) - if prop._is_self_referential() and not create_aliases: - raise exceptions.InvalidRequestError("Self-referential query on '%s' 
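params() now accepts a single dictionary as well as keyword arguments; a sketch with a text fragment and the hypothetical session::

    q = session.query(User).filter("users.name = :name")

    # keyword form
    ed = q.params(name='ed').first()

    # dictionary form, useful when the bind parameter keys are not valid
    # Python identifiers (e.g. unicode keys)
    ed = q.params({'name': 'ed'}).first()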
property requires create_aliases=True argument." % str(prop)) - - if prop.select_table not in currenttables or create_aliases: - if prop.secondary: - if create_aliases: - alias = mapperutil.PropertyAliasedClauses(prop, - prop.get_join(mapper, primary=True, secondary=False), - prop.get_join(mapper, primary=False, secondary=True), - alias - ) - clause = clause.join(alias.secondary, alias.primaryjoin, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin) - else: - clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False), isouter=outerjoin) - clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False), isouter=outerjoin) - else: - if create_aliases: - alias = mapperutil.PropertyAliasedClauses(prop, - prop.get_join(mapper, primary=True, secondary=False), - None, - alias - ) - clause = clause.join(alias.alias, alias.primaryjoin, isouter=outerjoin) - else: - clause = clause.join(prop.select_table, prop.get_join(mapper), isouter=outerjoin) - elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables: - # TODO: this check is not strong enough for different paths to the same endpoint which - # does not use secondary tables - raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key) - - mapper = prop.mapper - - if create_aliases: - return (clause, mapper, alias) - else: - return (clause, mapper, None) - - def _generative_col_aggregate(self, col, func): - """apply the given aggregate function to the query and return the newly - resulting ``Query``. - """ - if self._column_aggregate is not None: - raise exceptions.InvalidRequestError("Query already contains an aggregate column or function") - q = self._clone() - q._column_aggregate = (col, func) - return q - - def apply_min(self, col): - """apply the SQL ``min()`` function against the given column to the - query and return the newly resulting ``Query``. - """ - return self._generative_col_aggregate(col, sql.func.min) - - def apply_max(self, col): - """apply the SQL ``max()`` function against the given column to the - query and return the newly resulting ``Query``. - """ - return self._generative_col_aggregate(col, sql.func.max) - - def apply_sum(self, col): - """apply the SQL ``sum()`` function against the given column to the - query and return the newly resulting ``Query``. - """ - return self._generative_col_aggregate(col, sql.func.sum) - - def apply_avg(self, col): - """apply the SQL ``avg()`` function against the given column to the - query and return the newly resulting ``Query``. - """ - return self._generative_col_aggregate(col, sql.func.avg) - - def _col_aggregate(self, col, func): - """Execute ``func()`` function against the given column. + clauses = [self._joinpoint.get_property(key, resolve_synonyms=True).compare(operators.eq, value) + for key, value in kwargs.iteritems()] - For performance, only use subselect if `order_by` attribute is set. 
- """ + return self.filter(sql.and_(*clauses)) - ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj} - - if self._order_by is not False: - s1 = sql.select([col], self._criterion, **ops).alias('u') - return self.session.execute(sql.select([func(s1.corresponding_column(col))]), mapper=self.mapper).scalar() - else: - return self.session.execute(sql.select([func(col)], self._criterion, **ops), mapper=self.mapper).scalar() def min(self, col): """Execute the SQL ``min()`` function against the given column.""" @@ -427,91 +564,261 @@ class Query(object): """Execute the SQL ``avg()`` function against the given column.""" return self._col_aggregate(col, sql.func.avg) - - def order_by(self, criterion): + + def order_by(self, *criterion): """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" - q = self._clone() - if q._order_by is False: - q._order_by = util.to_list(criterion) + q = self.__no_statement("order_by") + + if self._aliases_tail: + criterion = tuple(self._aliases_tail.adapt_list( + [expression._literal_as_text(o) for o in criterion] + )) + + if q._order_by is False: + q._order_by = criterion else: - q._order_by = q._order_by + util.to_list(criterion) + q._order_by = q._order_by + criterion return q - - def group_by(self, criterion): + order_by = util.array_as_starargs_decorator(order_by) + + def group_by(self, *criterion): """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``""" - q = self._clone() - if q._group_by is False: - q._group_by = util.to_list(criterion) + q = self.__no_statement("group_by") + if q._group_by is False: + q._group_by = criterion + else: + q._group_by = q._group_by + criterion + return q + group_by = util.array_as_starargs_decorator(group_by) + + def having(self, criterion): + """apply a HAVING criterion to the query and return the newly resulting ``Query``.""" + + if isinstance(criterion, basestring): + criterion = sql.text(criterion) + + if criterion is not None and not isinstance(criterion, sql.ClauseElement): + raise exceptions.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string") + + if self._aliases_tail: + criterion = self._aliases_tail.adapt_clause(criterion) + + q = self.__no_statement("having") + if q._having is not None: + q._having = q._having & criterion else: - q._group_by = q._group_by + util.to_list(criterion) + q._having = criterion return q def join(self, prop, id=None, aliased=False, from_joinpoint=False): - """create a join of this ``Query`` object's criterion - to a relationship and return the newly resulting ``Query``. + """Create a join against this ``Query`` object's criterion + and apply generatively, retunring the newly resulting ``Query``. + + 'prop' may be one of: + * a string property name, i.e. "rooms" + * a class-mapped attribute, i.e. Houses.rooms + * a 2-tuple containing one of the above, combined with a selectable + which derives from the properties' mapped table + * a list (not a tuple) containing a combination of any of the above. 
+
+        e.g.::
+
+            session.query(Company).join('employees')
+            session.query(Company).join(['employees', 'tasks'])
+            session.query(Houses).join([Colonials.rooms, Room.closets])
+            session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
+
+        """
+        return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
+
+    def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False):
+        """Create a left outer join against this ``Query`` object's criterion
+        and apply generatively, returning the newly resulting ``Query``.
+
+        'prop' may be one of:
+          * a string property name, i.e. "rooms"
+          * a class-mapped attribute, i.e. Houses.rooms
+          * a 2-tuple containing one of the above, combined with a selectable
+            which derives from the property's mapped table
+          * a list (not a tuple) containing a combination of any of the above.
+
+        e.g.::
+
+            session.query(Company).outerjoin('employees')
+            session.query(Company).outerjoin(['employees', 'tasks'])
+            session.query(Houses).outerjoin([Colonials.rooms, Room.closets])
+            session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
+
+        """
+        return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint)
+
+    def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
+        (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
+        # TODO: improve the generative check here to look for primary mapped entity, etc.
+        q = self.__no_statement("join")
+        q._from_obj = clause
+        q._joinpoint = mapper
+        q._aliases = aliases
+        q.__generate_alias_ids()
+
+        if aliases:
+            q._aliases_tail = aliases
+
+        a = aliases
+        while a is not None:
+            if isinstance(a, mapperutil.PropertyAliasedClauses):
+                q._alias_ids.setdefault(a.mapper, []).append(a)
+                q._alias_ids.setdefault(a.table, []).append(a)
+                a = a.parentclauses
+            else:
+                break
+
+        if id:
+            q._alias_ids[id] = [aliases]
+        return q
+
+    def _get_joinable_tables(self):
+        if not self.__joinable_tables or self.__joinable_tables[0] is not self._from_obj:
+            currenttables = [self._from_obj]
+            def visit_join(join):
+                currenttables.append(join.left)
+                currenttables.append(join.right)
+            visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
+            self.__joinable_tables = (self._from_obj, currenttables)
+            return currenttables
+        else:
+            return self.__joinable_tables[1]
+
+    def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
+        if start is None:
+            start = self._joinpoint
+
+        clause = self._from_obj
+
+        currenttables = self._get_joinable_tables()
+
+        # determine if generated joins need to be aliased on the left
+        # hand side.
+        if self._aliases_head is self._aliases_tail is not None:
+            adapt_against = self._aliases_tail.alias
+        elif start is not self.mapper and self._aliases_tail:
+            adapt_against = self._aliases_tail.alias
+        else:
+            adapt_against = None

-        'prop' may be a string property name or a list of string
-        property names.
-        """
+        mapper = start
+        alias = self._aliases_tail

-        return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
-
-    def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False):
-        """create a left outer join of this ``Query`` object's criterion
-        to a relationship and return the newly resulting ``Query``.
- - 'prop' may be a string property name or a list of string - property names. - """ + if not isinstance(keys, list): + keys = [keys] + + for key in keys: + use_selectable = None + of_type = None + is_aliased_class = False + + if isinstance(key, tuple): + key, use_selectable = key + + if isinstance(key, interfaces.PropComparator): + prop = key.property + if getattr(key, '_of_type', None): + of_type = key._of_type + if not use_selectable: + use_selectable = key._of_type.mapped_table + else: + prop = mapper.get_property(key, resolve_synonyms=True) + + if use_selectable: + if _is_aliased_class(use_selectable): + use_selectable = use_selectable.alias + is_aliased_class = True + if not use_selectable.is_derived_from(prop.mapper.mapped_table): + raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description)) + if not isinstance(use_selectable, expression.Alias): + use_selectable = use_selectable.alias() + elif prop.mapper.with_polymorphic: + use_selectable = prop.mapper._with_polymorphic_selectable() + if not isinstance(use_selectable, expression.Alias): + use_selectable = use_selectable.alias() + + if prop._is_self_referential() and not create_aliases and not use_selectable: + raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires aliased=True argument." % str(prop)) + + if prop.table not in currenttables or create_aliases or use_selectable: + + if use_selectable or create_aliases: + alias = mapperutil.PropertyAliasedClauses(prop, + prop.primaryjoin, + prop.secondaryjoin, + alias, + alias=use_selectable, + should_adapt=not is_aliased_class + ) + crit = alias.primaryjoin + if prop.secondary: + clause = clause.join(alias.secondary, crit, isouter=outerjoin) + clause = clause.join(alias.alias, alias.secondaryjoin, isouter=outerjoin) + else: + clause = clause.join(alias.alias, crit, isouter=outerjoin) + else: + assert not prop.mapper.with_polymorphic + pj, sj, source, dest, target_adapter = prop._create_joins(source_selectable=adapt_against) + if sj: + clause = clause.join(prop.secondary, pj, isouter=outerjoin) + clause = clause.join(prop.table, sj, isouter=outerjoin) + else: + clause = clause.join(prop.table, pj, isouter=outerjoin) + + elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables: + # TODO: this check is not strong enough for different paths to the same endpoint which + # does not use secondary tables + raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." 
% prop.key) - return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint) + mapper = of_type or prop.mapper - def _join(self, prop, id, outerjoin, aliased, from_joinpoint): - (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased) - q = self._clone() - q._from_obj = [clause] - q._joinpoint = mapper - q._aliases = aliases + if use_selectable: + adapt_against = use_selectable - a = aliases - while a is not None: - q._alias_ids.setdefault(a.mapper, []).append(a) - q._alias_ids.setdefault(a.table, []).append(a) - q._alias_ids.setdefault(a.alias, []).append(a) - a = a.parentclauses - - if id: - q._alias_ids[id] = aliases - return q + return (clause, mapper, alias) + def reset_joinpoint(self): - """return a new Query reset the 'joinpoint' of this Query reset + """return a new Query reset the 'joinpoint' of this Query reset back to the starting mapper. Subsequent generative calls will be constructed from the new joinpoint. Note that each call to join() or outerjoin() also starts from the root. - """ - q = self._clone() + """ + q = self.__no_statement("reset_joinpoint") q._joinpoint = q.mapper - q._aliases = None + if q.table not in q._get_joinable_tables(): + q._aliases_head = q._aliases_tail = mapperutil.AliasedClauses(q._from_obj, equivalents=q.mapper._equivalent_columns) + else: + q._aliases_head = q._aliases_tail = None return q - def select_from(self, from_obj): - """Set the `from_obj` parameter of the query and return the newly - resulting ``Query``. + """Set the `from_obj` parameter of the query and return the newly + resulting ``Query``. This replaces the table which this Query selects + from with the given table. + + + `from_obj` is a single table or selectable. - `from_obj` is a list of one or more tables. """ + new = self.__no_criterion('select_from') + if isinstance(from_obj, (tuple, list)): + util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.") + from_obj = from_obj[-1] - new = self._clone() - new._from_obj = list(new._from_obj) + util.to_list(from_obj) + new.__set_select_from(from_obj) return new - + def __getitem__(self, item): if isinstance(item, slice): start = item.start @@ -538,24 +845,25 @@ class Query(object): def limit(self, limit): """Apply a ``LIMIT`` to the query and return the newly resulting + ``Query``. - """ + """ return self[:limit] def offset(self, offset): """Apply an ``OFFSET`` to the query and return the newly resulting ``Query``. - """ + """ return self[offset:] def distinct(self): """Apply a ``DISTINCT`` to the query and return the newly resulting ``Query``. - """ - new = self._clone() + """ + new = self.__no_statement("distinct") new._distinct = True return new @@ -563,26 +871,39 @@ class Query(object): """Return the results represented by this ``Query`` as a list. This results in an execution of the underlying query. + """ return list(self) - - + + def from_statement(self, statement): + """Execute the given SELECT statement and return results. + + This method bypasses all internal statement compilation, and the + statement is executed without modification. + + The statement argument is either a string, a ``select()`` construct, + or a ``text()`` construct, and should return the set of columns + appropriate to the entity class represented by this ``Query``. + + Also see the ``instances()`` method. 
+ + """ if isinstance(statement, basestring): statement = sql.text(statement) - q = self._clone() + q = self.__no_criterion('from_statement') q._statement = statement return q - + def first(self): - """Return the first result of this ``Query``. + """Return the first result of this ``Query`` or None if the result doesn't contain any row. This results in an execution of the underlying query. - """ - if self._column_aggregate is not None: + """ + if self._column_aggregate is not None: return self._col_aggregate(*self._column_aggregate) - + ret = list(self[0:1]) if len(ret) > 0: return ret[0] @@ -590,334 +911,302 @@ class Query(object): return None def one(self): - """Return the first result of this ``Query``, raising an exception if more than one row exists. + """Return the first result, raising an exception unless exactly one row exists. This results in an execution of the underlying query. - """ - if self._column_aggregate is not None: + """ + if self._column_aggregate is not None: return self._col_aggregate(*self._column_aggregate) ret = list(self[0:2]) - + if len(ret) == 1: return ret[0] elif len(ret) == 0: raise exceptions.InvalidRequestError('No rows returned for one()') else: raise exceptions.InvalidRequestError('Multiple rows returned for one()') - + def __iter__(self): - statement = self.compile() - statement.use_labels = True - if self.session.autoflush: - self.session.flush() - return self._execute_and_instances(statement) - - def _execute_and_instances(self, statement): - result = self.session.execute(statement, params=self._params, mapper=self.mapper) - try: - return iter(self.instances(result)) - finally: - result.close() + context = self._compile_context() + context.statement.use_labels = True + if self._autoflush and not self._populate_existing: + self.session._autoflush() + return self._execute_and_instances(context) - def instances(self, cursor, *mappers_or_columns, **kwargs): - """Return a list of mapped instances corresponding to the rows - in a given *cursor* (i.e. ``ResultProxy``). - - The \*mappers_or_columns and \**kwargs arguments are deprecated. - To add instances or columns to the results, use add_entity() - and add_column(). 
- """ + def _execute_and_instances(self, querycontext): + result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper, instance=self._refresh_instance) + return self.iterate_instances(result, querycontext=querycontext) - self.__log_debug("instances()") + def instances(self, cursor, *mappers_or_columns, **kwargs): + return list(self.iterate_instances(cursor, *mappers_or_columns, **kwargs)) + def iterate_instances(self, cursor, *mappers_or_columns, **kwargs): session = self.session - kwargs.setdefault('populate_existing', self._populate_existing) - kwargs.setdefault('version_check', self._version_check) + context = kwargs.pop('querycontext', None) + if context is None: + context = QueryContext(self) + + context.runid = _new_runid() + + entities = self._entities + [_QueryEntity.legacy_guess_type(mc) for mc in mappers_or_columns] - context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs) - - process = [] - mappers_or_columns = tuple(self._entities) + mappers_or_columns - if mappers_or_columns: - for tup in mappers_or_columns: - if isinstance(tup, tuple): - (m, alias, alias_id) = tup - clauses = self._get_entity_clauses(tup) - else: - clauses = alias = alias_id = None - m = tup - if isinstance(m, type): - m = mapper.class_mapper(m) - if isinstance(m, mapper.Mapper): - def x(m): - row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row) - appender = [] - def proc(context, row): - if not m._instance(context, row_adapter(row), appender): - appender.append(None) - process.append((proc, appender)) - x(m) - elif isinstance(m, (sql.ColumnElement, basestring)): - def y(m): - row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row) - res = [] - def proc(context, row): - res.append(row_adapter(row)[m]) - process.append((proc, res)) - y(m) - result = [] + if getattr(self, '_no_filters', False): + filter = None + single_entity = custom_rows = False else: - result = util.UniqueAppender([]) - - for row in cursor.fetchall(): - self.select_mapper._instance(context, row, result) - for proc in process: - proc[0](context, row) - - for instance in context.identity_map.values(): - context.attributes.get(('populating_mapper', instance), object_mapper(instance))._post_instance(context, instance) + single_entity = isinstance(entities[0], _PrimaryMapperEntity) and len(entities) == 1 + custom_rows = single_entity and 'append_result' in context.extension.methods + + if single_entity: + filter = util.OrderedIdentitySet + else: + filter = util.OrderedSet - # store new stuff in the identity map - for instance in context.identity_map.values(): - session._register_persistent(instance) + process = [query_entity.row_processor(self, context, single_entity) for query_entity in entities] - if mappers_or_columns: - return list(util.OrderedSet(zip(*([result] + [o[1] for o in process])))) - else: - return result.data + while True: + context.progress = util.Set() + context.partials = {} + + if self._yield_per: + fetch = cursor.fetchmany(self._yield_per) + if not fetch: + return + else: + fetch = cursor.fetchall() + + if custom_rows: + rows = [] + for row in fetch: + process[0](context, row, rows) + elif single_entity: + rows = [process[0](context, row) for row in fetch] + else: + rows = [tuple([proc(context, row) for proc in process]) for row in fetch] + + if filter: + rows = filter(rows) + + if context.refresh_instance and context.only_load_props and context.refresh_instance in context.progress: 
+ context.refresh_instance.commit(context.only_load_props) + context.progress.remove(context.refresh_instance) + + for ii in context.progress: + context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii) + ii.commit_all() + + for ii, attrs in context.partials.items(): + context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii, only_load_props=attrs) + ii.commit(attrs) + + for row in rows: + yield row + if not self._yield_per: + break - def _get(self, key, ident=None, reload=False, lockmode=None): + def _get(self, key=None, ident=None, refresh_instance=None, lockmode=None, only_load_props=None): lockmode = lockmode or self._lockmode - if not reload and not self.mapper.always_refresh and lockmode is None: + if not self._populate_existing and not refresh_instance and not self.mapper.always_refresh and lockmode is None: try: - return self.session._get(key) + return self.session.identity_map[key] except KeyError: pass if ident is None: - ident = key[1] + if key is not None: + ident = key[1] else: ident = util.to_list(ident) - params = {} + + q = self - for i, primary_key in enumerate(self.primary_key_columns): - try: - params[primary_key._label] = ident[i] - except IndexError: - raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns])) + # dont use 'polymorphic' mapper if we are refreshing an instance + if refresh_instance and q.mapper is not q.mapper: + q = q.__reset_all(q.mapper, '_get') + + if ident is not None: + q = q.__no_criterion('get') + params = {} + (_get_clause, _get_params) = q.mapper._get_clause + q = q.filter(_get_clause) + for i, primary_key in enumerate(q.mapper.primary_key): + try: + params[_get_params[primary_key].key] = ident[i] + except IndexError: + raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in q.mapper.primary_key])) + q = q.params(params) + + if lockmode is not None: + q = q.with_lockmode(lockmode) + q = q.__get_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance) + q._order_by = None try: - q = self - if lockmode is not None: - q = q.with_lockmode(lockmode) - q = q.filter(self.select_mapper._get_clause) - q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None)) - return q.first() + # call using all() to avoid LIMIT compilation complexity + return q.all()[0] except IndexError: return None - def _should_nest(self, querycontext): - """Return True if the given statement options indicate that we - should *nest* the generated query as a subquery inside of a - larger eager-loading query. This is used with keywords like - distinct, limit and offset and the mapper defines eager loads. 
- """ - - return ( - len(querycontext.eager_loaders) > 0 - and self._nestable(**querycontext.select_args()) - ) - - def _nestable(self, **kwargs): - """Return true if the given statement options imply it should be nested.""" - + def _select_args(self): + return {'limit':self._limit, 'offset':self._offset, 'distinct':self._distinct, 'group_by':self._group_by or None, 'having':self._having or None} + _select_args = property(_select_args) + + def _should_nest_selectable(self): + kwargs = self._select_args return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False)) + _should_nest_selectable = property(_should_nest_selectable) def count(self, whereclause=None, params=None, **kwargs): """Apply this query's criterion to a SELECT COUNT statement. the whereclause, params and \**kwargs arguments are deprecated. use filter() and other generative methods to establish modifiers. - """ + """ q = self if whereclause is not None: q = q.filter(whereclause) if params is not None: - q = q.params(**params) + q = q.params(params) q = q._legacy_select_kwargs(**kwargs) return q._count() def _count(self): """Apply this query's criterion to a SELECT COUNT statement. - - this is the purely generative version which will become + + this is the purely generative version which will become the public method in version 0.5. + """ + return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self.mapper.primary_key)) + def _col_aggregate(self, col, func, nested_cols=None): whereclause = self._criterion - + context = QueryContext(self) - from_obj = context.from_obj - - alltables = [] - for l in [sql_util.TableFinder(x) for x in from_obj]: - alltables += l - - if self.table not in alltables: - from_obj.append(self.table) - if self._nestable(**context.select_args()): - s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count() + from_obj = self._from_obj + + if self._should_nest_selectable: + if not nested_cols: + nested_cols = [col] + s = sql.select(nested_cols, whereclause, from_obj=from_obj, **self._select_args) + s = s.alias() + s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s) else: - primary_key = self.primary_key_columns - s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args()) + s = sql.select([func(col)], whereclause, from_obj=from_obj, **self._select_args) + + if self._autoflush and not self._populate_existing: + self.session._autoflush() return self.session.scalar(s, params=self._params, mapper=self.mapper) - + def compile(self): """compiles and returns a SQL statement based on the criterion and conditions within this Query.""" - - if self._statement: - self._statement.use_labels = True - return self._statement - - whereclause = self._criterion - if whereclause is not None and (self.mapper is not self.select_mapper): - # adapt the given WHERECLAUSE to adjust instances of this query's mapped - # table to be that of our select_table, - # which may be the "polymorphic" selectable used by our mapper. 
- sql_util.ClauseAdapter(self.table).traverse(whereclause, stop_on=util.Set([self.table])) - - # if extra entities, adapt the criterion to those as well - for m in self._entities: - if isinstance(m, type): - m = mapper.class_mapper(m) - if isinstance(m, mapper.Mapper): - table = m.select_table - sql_util.ClauseAdapter(m.select_table).traverse(whereclause, stop_on=util.Set([m.select_table])) - - # get/create query context. get the ultimate compile arguments - # from there - context = QueryContext(self) - order_by = context.order_by - from_obj = context.from_obj - lockmode = context.lockmode - if order_by is False: - order_by = self.mapper.order_by - if order_by is False: - if self.table.default_order_by() is not None: - order_by = self.table.default_order_by() + return self._compile_context().statement - try: - for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[lockmode] - except KeyError: - raise exceptions.ArgumentError("Unknown lockmode '%s'" % lockmode) + def _compile_context(self): - # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so - # that we only load the appropriate types - if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None: - whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_(*[m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()])) - - alltables = [] - for l in [sql_util.TableFinder(x) for x in from_obj]: - alltables += l - - if self.table not in alltables: - from_obj.append(self.table) - - if self._should_nest(context): - # if theres an order by, add those columns to the column list - # of the "rowcount" query we're going to make - if order_by: - order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []] - cf = sql_util.ColumnFinder() - for o in order_by: - cf.traverse(o) - else: - cf = [] - - s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args()) - if order_by: - s2 = s2.order_by(*util.to_list(order_by)) - s3 = s2.alias('tbl_row_count') - crit = s3.primary_key==self.primary_key_columns - statement = sql.select([], crit, use_labels=True, for_update=for_update) - # now for the order by, convert the columns to their corresponding columns - # in the "rowcount" query, and tack that new order by onto the "rowcount" query - if order_by: - statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by)) - else: - statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args()) - if order_by: - statement.append_order_by(*util.to_list(order_by)) - - # for a DISTINCT query, you need the columns explicitly specified in order - # to use it in "order_by". ensure they are in the column criterion (particularly oid). - # TODO: this should be done at the SQL level not the mapper level - # TODO: need test coverage for this - if context.distinct and order_by: - [statement.append_column(c) for c in util.to_list(order_by)] + context = QueryContext(self) - context.statement = statement - - # give all the attached properties a chance to modify the query - # TODO: doing this off the select_mapper. if its the polymorphic mapper, then - # it has no relations() on it. should we compile those too into the query ? (i.e. 
eagerloads) - for value in self.select_mapper.iterate_properties: - value.setup(context) - - # additional entities/columns, add those to selection criterion - for tup in self._entities: - (m, alias, alias_id) = tup - clauses = self._get_entity_clauses(tup) - if isinstance(m, mapper.Mapper): - for value in m.iterate_properties: - value.setup(context, parentclauses=clauses) - elif isinstance(m, sql.ColumnElement): - if clauses is not None: - m = clauses.adapt_clause(m) - statement.append_column(m) - - return statement + if self._statement: + self._statement.use_labels = True + context.statement = self._statement + return context - def _get_entity_clauses(self, m): - """for tuples added via add_entity() or add_column(), attempt to locate - an AliasedClauses object which should be used to formulate the query as well - as to process result rows.""" - (m, alias, alias_id) = m - if alias is not None: - return alias - if alias_id is not None: + from_obj = self._from_obj + adapter = self._aliases_head + + if self._lockmode: try: - return self._alias_ids[alias_id] + for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode] except KeyError: - raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id) - if isinstance(m, type): - m = mapper.class_mapper(m) - if isinstance(m, mapper.Mapper): - l = self._alias_ids.get(m) - if l: - if len(l) > 1: - raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id= to query.join()/query.add_entity()" % str(m)) - else: - return l[0] + raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode) + else: + for_update = False + + context.from_clause = from_obj + context.whereclause = self._criterion + context.order_by = self._order_by + + for entity in self._entities: + entity.setup_context(self, context) + + if self._eager_loaders and self._should_nest_selectable: + # eager loaders are present, and the SELECT has limiting criterion + # produce a "wrapped" selectable. 
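+            # e.g. query(SomeClass).options(eagerload('children')).limit(10): the
+            # LIMIT/OFFSET/DISTINCT must apply to the primary SomeClass rows only,
+            # so the primary columns are selected in an inner statement which is
+            # aliased, and the eager LEFT OUTER JOINs plus their columns are
+            # rendered against that alias in the enclosing SELECT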
+ + if context.order_by: + context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []] + if adapter: + context.order_by = adapter.adapt_list(context.order_by) + # locate all embedded Column clauses so they can be added to the + # "inner" select statement where they'll be available to the enclosing + # statement's "order by" + # TODO: this likely doesn't work with very involved ORDER BY expressions, + # such as those including subqueries + order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by])) else: - return None - elif isinstance(m, sql.ColumnElement): - aliases = [] - for table in sql_util.TableFinder(m, check_columns=True): - for a in self._alias_ids.get(table, []): - aliases.append(a) - if len(aliases) > 1: - raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id= to query.join()/query.add_column()" % str(m)) - elif len(aliases) == 1: - return aliases[0] + context.order_by = None + order_by_col_expr = [] + + if adapter: + context.primary_columns = adapter.adapt_list(context.primary_columns) + + inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=context.order_by, **self._select_args).alias() + local_adapter = sql_util.ClauseAdapter(inner) + + context.row_adapter = mapperutil.create_row_adapter(inner, equivalent_columns=self.mapper._equivalent_columns) + + statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=True) + + if context.eager_joins: + eager_joins = local_adapter.traverse(context.eager_joins) + statement.append_from(eager_joins) + + if context.order_by: + statement.append_order_by(*local_adapter.copy_and_process(context.order_by)) + + statement.append_order_by(*context.eager_order_by) + else: + if context.order_by: + context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []] + if adapter: + context.order_by = adapter.adapt_list(context.order_by) else: - return None + context.order_by = None + if adapter: + context.primary_columns = adapter.adapt_list(context.primary_columns) + context.row_adapter = mapperutil.create_row_adapter(adapter.alias, equivalent_columns=self.mapper._equivalent_columns) + + if self._distinct and context.order_by: + order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by])) + context.primary_columns += order_by_col_expr + + statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=context.order_by, **self._select_args) + + if context.eager_joins: + if adapter: + context.eager_joins = adapter.adapt_clause(context.eager_joins) + statement.append_from(context.eager_joins) + + if context.eager_order_by: + if adapter: + context.eager_order_by = adapter.adapt_list(context.eager_order_by) + statement.append_order_by(*context.eager_order_by) + + # polymorphic mappers which have concrete tables in their hierarchy usually + # require row aliasing unconditionally. + if not context.row_adapter and self.mapper._requires_row_aliasing: + context.row_adapter = mapperutil.create_row_adapter(self.table, equivalent_columns=self.mapper._equivalent_columns) + + context.statement = statement + + return context + def __log_debug(self, msg): self.logger.debug(msg) @@ -926,41 +1215,90 @@ class Query(object): # DEPRECATED LAND ! 
- def list(self): + def _generative_col_aggregate(self, col, func): + """apply the given aggregate function to the query and return the newly + resulting ``Query``. (deprecated) + """ + if self._column_aggregate is not None: + raise exceptions.InvalidRequestError("Query already contains an aggregate column or function") + q = self.__no_statement("aggregate") + q._column_aggregate = (col, func) + return q + + def apply_min(self, col): + """apply the SQL ``min()`` function against the given column to the + query and return the newly resulting ``Query``. + + DEPRECATED. + """ + return self._generative_col_aggregate(col, sql.func.min) + + def apply_max(self, col): + """apply the SQL ``max()`` function against the given column to the + query and return the newly resulting ``Query``. + + DEPRECATED. + """ + return self._generative_col_aggregate(col, sql.func.max) + + def apply_sum(self, col): + """apply the SQL ``sum()`` function against the given column to the + query and return the newly resulting ``Query``. + + DEPRECATED. + """ + return self._generative_col_aggregate(col, sql.func.sum) + + def apply_avg(self, col): + """apply the SQL ``avg()`` function against the given column to the + query and return the newly resulting ``Query``. + + DEPRECATED. + """ + return self._generative_col_aggregate(col, sql.func.avg) + + def list(self): #pragma: no cover """DEPRECATED. use all()""" return list(self) - def scalar(self): + def scalar(self): #pragma: no cover """DEPRECATED. use first()""" return self.first() - def _legacy_filter_by(self, *args, **kwargs): + def _legacy_filter_by(self, *args, **kwargs): #pragma: no cover return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint)) - def count_by(self, *args, **params): + def count_by(self, *args, **params): #pragma: no cover """DEPRECATED. use query.filter_by(\**params).count()""" return self.count(self.join_by(*args, **params)) - def select_whereclause(self, whereclause=None, params=None, **kwargs): + def select_whereclause(self, whereclause=None, params=None, **kwargs): #pragma: no cover """DEPRECATED. use query.filter(whereclause).all()""" q = self.filter(whereclause)._legacy_select_kwargs(**kwargs) if params is not None: - q = q.params(**params) + q = q.params(params) return list(q) - - def _legacy_select_kwargs(self, **kwargs): + + def _legacy_select_from(self, from_obj): + q = self._clone() + if len(from_obj) > 1: + raise exceptions.ArgumentError("Multiple-entry from_obj parameter no longer supported") + q._from_obj = from_obj[0] + return q + + def _legacy_select_kwargs(self, **kwargs): #pragma: no cover q = self if "order_by" in kwargs and kwargs['order_by']: q = q.order_by(kwargs['order_by']) if "group_by" in kwargs: q = q.group_by(kwargs['group_by']) if "from_obj" in kwargs: - q = q.select_from(kwargs['from_obj']) + q = q._legacy_select_from(kwargs['from_obj']) if "lockmode" in kwargs: q = q.with_lockmode(kwargs['lockmode']) if "distinct" in kwargs: @@ -972,89 +1310,84 @@ class Query(object): return q - def get_by(self, *args, **params): + def get_by(self, *args, **params): #pragma: no cover """DEPRECATED. use query.filter_by(\**params).first()""" ret = self._extension.get_by(self, *args, **params) - if ret is not mapper.EXT_PASS: + if ret is not mapper.EXT_CONTINUE: return ret return self._legacy_filter_by(*args, **params).first() - def select_by(self, *args, **params): + def select_by(self, *args, **params): #pragma: no cover """DEPRECATED. 
use use query.filter_by(\**params).all().""" ret = self._extension.select_by(self, *args, **params) - if ret is not mapper.EXT_PASS: + if ret is not mapper.EXT_CONTINUE: return ret return self._legacy_filter_by(*args, **params).list() - def join_by(self, *args, **params): + def join_by(self, *args, **params): #pragma: no cover """DEPRECATED. use join() to construct joins based on attribute names.""" return self._legacy_join_by(args, params, start=self._joinpoint) - def _build_select(self, arg=None, params=None, **kwargs): + def _build_select(self, arg=None, params=None, **kwargs): #pragma: no cover if isinstance(arg, sql.FromClause) and arg.supports_execution(): return self.from_statement(arg) - else: + elif arg is not None: return self.filter(arg)._legacy_select_kwargs(**kwargs) + else: + return self._legacy_select_kwargs(**kwargs) - def selectfirst(self, arg=None, **kwargs): + def selectfirst(self, arg=None, **kwargs): #pragma: no cover """DEPRECATED. use query.filter(whereclause).first()""" return self._build_select(arg, **kwargs).first() - def selectone(self, arg=None, **kwargs): + def selectone(self, arg=None, **kwargs): #pragma: no cover """DEPRECATED. use query.filter(whereclause).one()""" return self._build_select(arg, **kwargs).one() - def select(self, arg=None, **kwargs): + def select(self, arg=None, **kwargs): #pragma: no cover """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()""" ret = self._extension.select(self, arg=arg, **kwargs) - if ret is not mapper.EXT_PASS: + if ret is not mapper.EXT_CONTINUE: return ret return self._build_select(arg, **kwargs).all() - def execute(self, clauseelement, params=None, *args, **kwargs): + def execute(self, clauseelement, params=None, *args, **kwargs): #pragma: no cover """DEPRECATED. use query.from_statement().all()""" return self._select_statement(clauseelement, params, **kwargs) - def select_statement(self, statement, **params): + def select_statement(self, statement, **params): #pragma: no cover """DEPRECATED. Use query.from_statement(statement)""" - + return self._select_statement(statement, params) - def select_text(self, text, **params): + def select_text(self, text, **params): #pragma: no cover """DEPRECATED. Use query.from_statement(statement)""" return self._select_statement(text, params) - def _select_statement(self, statement, params=None, **kwargs): + def _select_statement(self, statement, params=None, **kwargs): #pragma: no cover q = self.from_statement(statement) if params is not None: - q = q.params(**params) - q._select_context_options(**kwargs) + q = q.params(params) + q.__get_options(**kwargs) return list(q) - def _select_context_options(self, populate_existing=None, version_check=None): - if populate_existing is not None: - self._populate_existing = populate_existing - if version_check is not None: - self._version_check = version_check - return self - - def join_to(self, key): + def join_to(self, key): #pragma: no cover """DEPRECATED. use join() to create joins based on property names.""" [keys, p] = self._locate_prop(key) return self.join_via(keys) - def join_via(self, keys): + def join_via(self, keys): #pragma: no cover """DEPRECATED. 
use join() to create joins based on property names.""" mapper = self._joinpoint @@ -1062,14 +1395,14 @@ class Query(object): for key in keys: prop = mapper.get_property(key, resolve_synonyms=True) if clause is None: - clause = prop.get_join(mapper) + clause = prop._get_join(mapper) else: - clause &= prop.get_join(mapper) + clause &= prop._get_join(mapper) mapper = prop.mapper return clause - def _legacy_join_by(self, args, params, start=None): + def _legacy_join_by(self, args, params, start=None): #pragma: no cover import properties clause = None @@ -1082,16 +1415,16 @@ class Query(object): for key, value in params.iteritems(): (keys, prop) = self._locate_prop(key, start=start) if isinstance(prop, properties.PropertyLoader): - c = prop.compare(operator.eq, value) & self.join_via(keys[:-1]) + c = prop.compare(operators.eq, value) & self.join_via(keys[:-1]) else: - c = prop.compare(operator.eq, value) & self.join_via(keys) + c = prop.compare(operators.eq, value) & self.join_via(keys) if clause is None: clause = c else: clause &= c return clause - def _locate_prop(self, key, start=None): + def _locate_prop(self, key, start=None): #pragma: no cover import properties keys = [] seen = util.Set() @@ -1099,7 +1432,7 @@ class Query(object): if mapper_ in seen: return None seen.add(mapper_) - + prop = mapper_.get_property(key, resolve_synonyms=True, raiseerr=False) if prop is not None: if isinstance(prop, properties.PropertyLoader): @@ -1120,99 +1453,231 @@ class Query(object): raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key) return [keys, p] - def selectfirst_by(self, *args, **params): + def selectfirst_by(self, *args, **params): #pragma: no cover """DEPRECATED. Use query.filter_by(\**kwargs).first()""" return self._legacy_filter_by(*args, **params).first() - def selectone_by(self, *args, **params): + def selectone_by(self, *args, **params): #pragma: no cover """DEPRECATED. Use query.filter_by(\**kwargs).one()""" return self._legacy_filter_by(*args, **params).one() + for deprecated_method in ('list', 'scalar', 'count_by', + 'select_whereclause', 'get_by', 'select_by', + 'join_by', 'selectfirst', 'selectone', 'select', + 'execute', 'select_statement', 'select_text', + 'join_to', 'join_via', 'selectfirst_by', + 'selectone_by', 'apply_max', 'apply_min', + 'apply_avg', 'apply_sum'): + locals()[deprecated_method] = \ + util.deprecated(None, False)(locals()[deprecated_method]) + +class _QueryEntity(object): + """represent an entity column returned within a Query result.""" + + def legacy_guess_type(self, e): + if isinstance(e, type): + return _MapperEntity(mapper=mapper.class_mapper(e)) + elif isinstance(e, mapper.Mapper): + return _MapperEntity(mapper=e) + else: + return _ColumnEntity(column=e) + legacy_guess_type=classmethod(legacy_guess_type) +class _MapperEntity(_QueryEntity): + """entity column corresponding to mapped ORM instances.""" + + def __init__(self, mapper, alias=None, id=None): + self.mapper = mapper + self.alias = alias + self.alias_id = id + + def _get_entity_clauses(self, query): + if self.alias: + return self.alias + elif self.alias_id: + try: + return query._alias_ids[self.alias_id][0] + except KeyError: + raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id) -Query.logger = logging.class_logger(Query) - -class QueryContext(OperationContext): - """Created within the ``Query.compile()`` method to store and - share state among all the Mappers and MapperProperty objects used - in a query construction. 
- """ - - def __init__(self, query): - self.query = query - self.order_by = query._order_by - self.group_by = query._group_by - self.from_obj = query._from_obj - self.lockmode = query._lockmode - self.distinct = query._distinct - self.limit = query._limit - self.offset = query._offset - self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders]) - self.statement = None - super(QueryContext, self).__init__(query.mapper, query._with_options) + l = query._alias_ids.get(self.mapper) + if l: + if len(l) > 1: + raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id= to query.join()/query.add_entity()" % str(self.mapper)) + return l[0] + else: + return None + + def row_processor(self, query, context, single_entity): + clauses = self._get_entity_clauses(query) + if clauses: + def proc(context, row): + return self.mapper._instance(context, clauses.row_decorator(row), None) + else: + def proc(context, row): + return self.mapper._instance(context, row, None) + + return proc + + def setup_context(self, query, context): + clauses = self._get_entity_clauses(query) + for value in self.mapper.iterate_properties: + context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses) - def select_args(self): - """Return a dictionary of attributes from this - ``QueryContext`` that can be applied to a ``sql.Select`` - statement. - """ - return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None} + def __str__(self): + return str(self.mapper) + +class _PrimaryMapperEntity(_MapperEntity): + """entity column corresponding to the 'primary' (first) mapped ORM instance.""" + + def row_processor(self, query, context, single_entity): + if single_entity and 'append_result' in context.extension.methods: + def main(context, row, result): + if context.row_adapter: + row = context.row_adapter(row) + self.mapper._instance(context, row, result, + extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance + ) + elif context.row_adapter: + def main(context, row): + return self.mapper._instance(context, context.row_adapter(row), None, + extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance + ) + else: + def main(context, row): + return self.mapper._instance(context, row, None, + extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance + ) + + return main - def accept_option(self, opt): - """Accept a ``MapperOption`` which will process (modify) the - state of this ``QueryContext``. 
- """ + def setup_context(self, query, context): + # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so + # that we only load the appropriate types + if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None: + context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()])) + + if context.order_by is False: + if self.mapper.order_by: + context.order_by = self.mapper.order_by + elif context.from_clause.default_order_by(): + context.order_by = context.from_clause.default_order_by() + + for value in self.mapper._iterate_polymorphic_properties(query._with_polymorphic, context.from_clause): + if query._only_load_props and value.key not in query._only_load_props: + continue + context.exec_with_path(self.mapper, value.key, value.setup, context, only_load_props=query._only_load_props) - opt.process_query_context(self) +class _ColumnEntity(_QueryEntity): + """entity column corresponding to Table or selectable columns.""" + def __init__(self, column, id): + if isinstance(column, basestring): + column = sql.literal_column(column) + + if column and isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'): + column = column.label(None) + self.column = column + self.alias_id = id -class SelectionContext(OperationContext): - """Created within the ``query.instances()`` method to store and share - state among all the Mappers and MapperProperty objects used in a - load operation. + def __resolve_expr_against_query_aliases(self, query, expr, context): + if not query._alias_ids: + return expr + + if ('_ColumnEntity', expr) in context.attributes: + return context.attributes[('_ColumnEntity', expr)] + + if self.alias_id: + try: + aliases = query._alias_ids[self.alias_id][0] + except KeyError: + raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id) - SelectionContext contains these attributes: + def _locate_aliased(element): + if element in query._alias_ids: + return aliases + else: + def _locate_aliased(element): + if element in query._alias_ids: + aliases = query._alias_ids[element] + if len(aliases) > 1: + raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id= to query.join()/query.add_column(), or use the aliased() function to use explicit class aliases." % expr) + return aliases[0] + return None - mapper - The Mapper which originated the instances() call. + class Adapter(visitors.ClauseVisitor): + def before_clone(self, element): + if isinstance(element, expression.FromClause): + alias = _locate_aliased(element) + if alias: + return alias.alias + + if hasattr(element, 'table'): + alias = _locate_aliased(element.table) + if alias: + return alias.aliased_column(element) - session - The Session that is relevant to the instances call. + return None - identity_map - A dictionary which stores newly created instances that have not - yet been added as persistent to the Session. 
+ context.attributes[('_ColumnEntity', expr)] = ret = Adapter().traverse(expr, clone=True) + return ret + + def row_processor(self, query, context, single_entity): + column = self.__resolve_expr_against_query_aliases(query, self.column, context) + def proc(context, row): + return row[column] + return proc + + def setup_context(self, query, context): + column = self.__resolve_expr_against_query_aliases(query, self.column, context) + context.secondary_columns.append(column) + + def __str__(self): + return str(self.column) - attributes - A dictionary to store arbitrary data; mappers, strategies, and - options all store various state information here in order - to communicate with each other and to themselves. - + +Query.logger = logging.class_logger(Query) - populate_existing - Indicates if its OK to overwrite the attributes of instances - that were already in the Session. +class QueryContext(object): + def __init__(self, query): + self.query = query + self.mapper = query.mapper + self.session = query.session + self.extension = query._extension + self.statement = None + self.row_adapter = None + self.populate_existing = query._populate_existing + self.version_check = query._version_check + self.only_load_props = query._only_load_props + self.refresh_instance = query._refresh_instance + self.path = () + self.primary_columns = [] + self.secondary_columns = [] + self.eager_order_by = [] + self.eager_joins = None + self.options = query._with_options + self.attributes = query._attributes.copy() + + def exec_with_path(self, mapper, propkey, fn, *args, **kwargs): + oldpath = self.path + self.path += (mapper.base_mapper, propkey) + try: + return fn(*args, **kwargs) + finally: + self.path = oldpath - version_check - Indicates if mappers that have version_id columns should verify - that instances existing already within the Session should have - this attribute compared to the freshly loaded value. - """ - def __init__(self, mapper, session, extension, **kwargs): - self.populate_existing = kwargs.pop('populate_existing', False) - self.version_check = kwargs.pop('version_check', False) - self.session = session - self.extension = extension - self.identity_map = {} - self.stack = LoaderStack() - super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs) - def accept_option(self, opt): - """Accept a MapperOption which will process (modify) the state - of this SelectionContext. - """ +_runid = 1L +_id_lock = util.threading.Lock() - opt.process_selection_context(self) +def _new_runid(): + global _runid + _id_lock.acquire() + try: + _runid += 1 + return _runid + finally: + _id_lock.release() diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py new file mode 100644 index 0000000000..479b2f7374 --- /dev/null +++ b/lib/sqlalchemy/orm/scoping.py @@ -0,0 +1,174 @@ +from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs +from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, class_mapper +from sqlalchemy.orm.session import Session +from sqlalchemy import exceptions +import types + +__all__ = ['ScopedSession'] + + +class ScopedSession(object): + """Provides thread-local management of Sessions. + + Usage:: + + Session = scoped_session(sessionmaker(autoflush=True)) + + To map classes so that new instances are saved in the current + Session automatically, as well as to provide session-aware + class attributes such as "query": + + mapper = Session.mapper + mapper(Class, table, ...) 
+ + """ + + def __init__(self, session_factory, scopefunc=None): + self.session_factory = session_factory + self.registry = ScopedRegistry(session_factory, scopefunc) + self.extension = _ScopedExt(self) + + def __call__(self, **kwargs): + if kwargs: + scope = kwargs.pop('scope', False) + if scope is not None: + if self.registry.has(): + raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.") + else: + sess = self.session_factory(**kwargs) + self.registry.set(sess) + return sess + else: + return self.session_factory(**kwargs) + else: + return self.registry() + + def remove(self): + if self.registry.has(): + self.registry().close() + self.registry.clear() + + def mapper(self, *args, **kwargs): + """return a mapper() function which associates this ScopedSession with the Mapper.""" + + from sqlalchemy.orm import mapper + + extension_args = dict([(arg,kwargs.pop(arg)) + for arg in get_cls_kwargs(_ScopedExt) + if arg in kwargs]) + + kwargs['extension'] = extension = to_list(kwargs.get('extension', [])) + if extension_args: + extension.append(self.extension.configure(**extension_args)) + else: + extension.append(self.extension) + return mapper(*args, **kwargs) + + def configure(self, **kwargs): + """reconfigure the sessionmaker used by this ScopedSession.""" + + self.session_factory.configure(**kwargs) + + def query_property(self): + """return a class property which produces a `Query` object against the + class when called. + + e.g.:: + Session = scoped_session(sessionmaker()) + + class MyClass(object): + query = Session.query_property() + + # after mappers are defined + result = MyClass.query.filter(MyClass.name=='foo').all() + + """ + + class query(object): + def __get__(s, instance, owner): + mapper = class_mapper(owner, raiseerror=False) + if mapper: + return self.registry().query(mapper) + else: + return None + return query() + +def instrument(name): + def do(self, *args, **kwargs): + return getattr(self.registry(), name)(*args, **kwargs) + return do +for meth in ('get', 'load', 'close', 'save', 'commit', 'update', 'save_or_update', 'flush', 'query', 'delete', 'merge', 'clear', 'refresh', 'expire', 'expunge', 'rollback', 'begin', 'begin_nested', 'connection', 'execute', 'scalar', 'get_bind', 'is_modified', '__contains__', '__iter__'): + setattr(ScopedSession, meth, instrument(meth)) + +def makeprop(name): + def set(self, attr): + setattr(self.registry(), name, attr) + def get(self): + return getattr(self.registry(), name) + return property(get, set) +for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map'): + setattr(ScopedSession, prop, makeprop(prop)) + +def clslevel(name): + def do(cls, *args,**kwargs): + return getattr(Session, name)(*args, **kwargs) + return classmethod(do) +for prop in ('close_all','object_session', 'identity_key'): + setattr(ScopedSession, prop, clslevel(prop)) + +class _ScopedExt(MapperExtension): + def __init__(self, context, validate=False, save_on_init=True): + self.context = context + self.validate = validate + self.save_on_init = save_on_init + + def validating(self): + return _ScopedExt(self.context, validate=True) + + def configure(self, **kwargs): + return _ScopedExt(self.context, **kwargs) + + def get_session(self): + return self.context.registry() + + def instrument_class(self, mapper, class_): + class query(object): + def __getattr__(s, key): + return getattr(self.context.registry().query(class_), key) + def __call__(s): + return self.context.registry().query(class_) + + if not 'query' in 
class_.__dict__: + class_.query = query() + + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): + if self.save_on_init: + entity_name = kwargs.pop('_sa_entity_name', None) + session = kwargs.pop('_sa_session', None) + if not isinstance(oldinit, types.MethodType): + for key, value in kwargs.items(): + if self.validate: + if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False): + raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) + setattr(instance, key, value) + kwargs.clear() + if self.save_on_init: + session = session or self.context.registry() + session._save_impl(instance, entity_name=entity_name) + return EXT_CONTINUE + + def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): + object_session(instance).expunge(instance) + return EXT_CONTINUE + + def dispose_class(self, mapper, class_): + if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): + if class_.__init__._oldinit is not None: + class_.__init__ = class_.__init__._oldinit + else: + delattr(class_, '__init__') + if hasattr(class_, 'query'): + delattr(class_, 'query') + + + diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 6b5c4a0725..57f23ace29 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1,118 +1,313 @@ -# objectstore.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# session.py +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import weakref +"""Provides the Session class and related utilities.""" + +import weakref from sqlalchemy import util, exceptions, sql, engine -from sqlalchemy.orm import unitofwork, query, util as mapperutil +from sqlalchemy.orm import unitofwork, query, attributes, util as mapperutil from sqlalchemy.orm.mapper import object_mapper as _object_mapper from sqlalchemy.orm.mapper import class_mapper as _class_mapper +from sqlalchemy.orm.mapper import Mapper + + +__all__ = ['Session', 'SessionTransaction', 'SessionExtension'] + +def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs): + """Generate a custom-configured [sqlalchemy.orm.session#Session] class. + + The returned object is a subclass of ``Session``, which, when instantiated with no + arguments, uses the keyword arguments configured here as its constructor arguments. + + It is intended that the `sessionmaker()` function be called within the global scope + of an application, and the returned class be made available to the rest of the + application as the single class used to instantiate sessions. + + e.g.:: + + # global scope + Session = sessionmaker(autoflush=False) + + # later, in a local scope, create and use a session: + sess = Session() + + Any keyword arguments sent to the constructor itself will override the "configured" + keywords:: + + Session = sessionmaker() + + # bind an individual session to a connection + sess = Session(bind=connection) + + The class also includes a special classmethod ``configure()``, which allows + additional configurational options to take place after the custom ``Session`` + class has been generated. 
This is useful particularly for defining the + specific ``Engine`` (or engines) to which new instances of ``Session`` + should be bound:: + + Session = sessionmaker() + Session.configure(bind=create_engine('sqlite:///foo.db')) + + sess = Session() + + The function features a single keyword argument of its own, `class_`, which + may be used to specify an alternate class other than ``sqlalchemy.orm.session.Session`` + which should be used by the returned class. All other keyword arguments sent to + `sessionmaker()` are passed through to the instantiated `Session()` object. + """ + + kwargs['bind'] = bind + kwargs['autoflush'] = autoflush + kwargs['transactional'] = transactional + + if class_ is None: + class_ = Session + + class Sess(class_): + def __init__(self, **local_kwargs): + for k in kwargs: + local_kwargs.setdefault(k, kwargs[k]) + super(Sess, self).__init__(**local_kwargs) + + def configure(self, **new_kwargs): + """(re)configure the arguments for this sessionmaker. + + e.g. + Session = sessionmaker() + Session.configure(bind=create_engine('sqlite://')) + """ + + kwargs.update(new_kwargs) + configure = classmethod(configure) + + return Sess + +class SessionExtension(object): + """An extension hook object for Sessions. Subclasses may be installed into a Session + (or sessionmaker) using the ``extension`` keyword argument. + """ + + def before_commit(self, session): + """Execute right before commit is called. + + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def after_commit(self, session): + """Execute after a commit has occured. + + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def after_rollback(self, session): + """Execute after a rollback has occured. + + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def before_flush(self, session, flush_context, instances): + """Execute before flush process has started. + + `instances` is an optional list of objects which were passed to the ``flush()`` + method. + """ + + def after_flush(self, session, flush_context): + """Execute after flush has completed, but before commit has been called. + + Note that the session's state is still in pre-flush, i.e. 'new', 'dirty', + and 'deleted' lists still show pre-flush state as well as the history + settings on instance attributes.""" + + def after_flush_postexec(self, session, flush_context): + """Execute after flush has completed, and after the post-exec state occurs. + + This will be when the 'new', 'dirty', and 'deleted' lists are in their final + state. An actual commit() may or may not have occured, depending on whether or not + the flush started its own transaction or participated in a larger transaction. + """ + + def after_begin(self, session, transaction, connection): + """Execute after a transaction is begun on a connection + + `transaction` is the SessionTransaction. This method is called after an + engine level transaction is begun on a connection. + """ class SessionTransaction(object): """Represents a Session-level Transaction. - This corresponds to one or more sqlalchemy.engine.Transaction - instances behind the scenes, with one Transaction per Engine in + This corresponds to one or more [sqlalchemy.engine#Transaction] + instances behind the scenes, with one ``Transaction`` per ``Engine`` in use. - The SessionTransaction object is **not** threadsafe. 
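As a rough illustration of the ``SessionExtension`` hooks documented above (the subclass and its print statements are hypothetical; only the hook signatures come from the source), an extension instance can be handed to ``sessionmaker()``, which forwards the ``extension`` keyword to each ``Session`` it creates::

    from sqlalchemy.orm.session import sessionmaker, SessionExtension

    class AuditExtension(SessionExtension):
        def before_commit(self, session):
            # runs just before each commit, per the hook docs above
            print "about to commit", session

        def after_flush(self, session, flush_context):
            # session state is still in its pre-flush form here
            print "flush completed for", session

    Session = sessionmaker(extension=AuditExtension())
    sess = Session()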
+ Direct usage of ``SessionTransaction`` is not necessary as of + SQLAlchemy 0.4; use the ``begin()`` and ``commit()`` methods on + ``Session`` itself. + + The ``SessionTransaction`` object is **not** threadsafe. """ def __init__(self, session, parent=None, autoflush=True, nested=False): self.session = session - self.__connections = {} - self.__parent = parent + self._connections = {} + self._parent = parent self.autoflush = autoflush self.nested = nested + self._active = True + self._prepared = False - def connection(self, mapper_or_class, entity_name=None, **kwargs): - if isinstance(mapper_or_class, type): - mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name) - engine = self.session.get_bind(mapper_or_class, **kwargs) + is_active = property(lambda s: s.session is not None and s._active) + + def _assert_is_active(self): + self._assert_is_open() + if not self._active: + raise exceptions.InvalidRequestError("The transaction is inactive due to a rollback in a subtransaction and should be closed") + + def _assert_is_open(self): + if self.session is None: + raise exceptions.InvalidRequestError("The transaction is closed") + + def connection(self, bindkey, **kwargs): + self._assert_is_active() + engine = self.session.get_bind(bindkey, **kwargs) return self.get_or_add(engine) def _begin(self, **kwargs): + self._assert_is_active() return SessionTransaction(self.session, self, **kwargs) + def _iterate_parents(self, upto=None): + if self._parent is upto: + return (self,) + else: + if self._parent is None: + raise exceptions.InvalidRequestError("Transaction %s is not on the active transaction list" % upto) + return (self,) + self._parent._iterate_parents(upto) + def add(self, bind): - if self.__parent is not None: - return self.__parent.add(bind) - - if self.__connections.has_key(bind.engine): + self._assert_is_active() + if self._parent is not None and not self.nested: + return self._parent.add(bind) + + if bind.engine in self._connections: raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or "")) return self.get_or_add(bind) - def _connection_dict(self): - if self.__parent is not None and not self.nested: - return self.__parent._connection_dict() - else: - return self.__connections - def get_or_add(self, bind): - if self.__parent is not None: + self._assert_is_active() + + if bind in self._connections: + return self._connections[bind][0] + + if self._parent is not None: + conn = self._parent.get_or_add(bind) if not self.nested: - return self.__parent.get_or_add(bind) - - if self.__connections.has_key(bind): - return self.__connections[bind][0] - - if bind in self.__parent._connection_dict(): - (conn, trans, autoclose) = self.__parent.__connections[bind] - self.__connections[conn] = self.__connections[bind.engine] = (conn, conn.begin_nested(), autoclose) return conn - elif self.__connections.has_key(bind): - return self.__connections[bind][0] - - if not isinstance(bind, engine.Connection): - e = bind - c = bind.contextual_connect() else: - e = bind.engine - c = bind - if e in self.__connections: - raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") - if self.nested: - trans = c.begin_nested() - elif self.session.twophase: - trans = c.begin_twophase() - else: - trans = c.begin() - self.__connections[c] = self.__connections[e] = (c, trans, c is not bind) - return self.__connections[c][0] + if 
isinstance(bind, engine.Connection): + conn = bind + if conn.engine in self._connections: + raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") + else: + conn = bind.contextual_connect() - def commit(self): - if self.__parent is not None and not self.nested: - return self.__parent + if self.session.twophase and self._parent is None: + transaction = conn.begin_twophase() + elif self.nested: + transaction = conn.begin_nested() + else: + transaction = conn.begin() + + self._connections[conn] = self._connections[conn.engine] = (conn, transaction, conn is not bind) + if self.session.extension is not None: + self.session.extension.after_begin(self.session, self, conn) + return conn + + def prepare(self): + if self._parent is not None or not self.session.twophase: + raise exceptions.InvalidRequestError("Only root two phase transactions of can be prepared") + self._prepare_impl() + + def _prepare_impl(self): + self._assert_is_active() + if self.session.extension is not None and (self._parent is None or self.nested): + self.session.extension.before_commit(self.session) + + if self.session.transaction is not self: + for subtransaction in self.session.transaction._iterate_parents(upto=self): + subtransaction.commit() + if self.autoflush: self.session.flush() + + if self._parent is None and self.session.twophase: + try: + for t in util.Set(self._connections.values()): + t[1].prepare() + except: + self.rollback() + raise + + self._deactivate() + self._prepared = True + + def commit(self): + self._assert_is_open() + if not self._prepared: + self._prepare_impl() + + if self._parent is None or self.nested: + for t in util.Set(self._connections.values()): + t[1].commit() - if self.session.twophase: - for t in util.Set(self.__connections.values()): - t[1].prepare() + if self.session.extension is not None: + self.session.extension.after_commit(self.session) - for t in util.Set(self.__connections.values()): - t[1].commit() self.close() - return self.__parent + return self._parent def rollback(self): - if self.__parent is not None and not self.nested: - return self.__parent.rollback() - for t in util.Set(self.__connections.values()): - t[1].rollback() - self.close() - return self.__parent + self._assert_is_open() + if self.session.transaction is not self: + for subtransaction in self.session.transaction._iterate_parents(upto=self): + subtransaction.close() + + if self.is_active or self._prepared: + for transaction in self._iterate_parents(): + if transaction._parent is None or transaction.nested: + transaction._rollback_impl() + transaction._deactivate() + break + else: + transaction._deactivate() + + self.close() + return self._parent + + def _rollback_impl(self): + for t in util.Set(self._connections.values()): + t[1].rollback() + + if self.session.extension is not None: + self.session.extension.after_rollback(self.session) + + def _deactivate(self): + self._active = False + def close(self): - if self.__parent is not None: - return - for t in util.Set(self.__connections.values()): - if t[2]: - t[0].close() - self.session.transaction = None + self.session.transaction = self._parent + if self._parent is None: + for connection, transaction, autoclose in util.Set(self._connections.values()): + if autoclose: + connection.close() + else: + transaction.close() + self._deactivate() + self.session = None + self._connections = None def __enter__(self): return self @@ -121,43 +316,182 @@ class SessionTransaction(object): if self.session.transaction is 
None: return if type is None: - self.commit() + try: + self.commit() + except: + self.rollback() + raise else: self.rollback() class Session(object): - """Encapsulates a set of objects being operated upon within an - object-relational operation. + """Encapsulates a set of objects being operated upon within an object-relational operation. + + The Session is the front end to SQLAlchemy's **Unit of Work** implementation. The concept + behind Unit of Work is to track modifications to a field of objects, and then be able to + flush those changes to the database in a single operation. - The Session object is **not** threadsafe. For thread-management - of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module. + SQLAlchemy's unit of work includes these functions: + + * The ability to track in-memory changes on scalar- and collection-based object + attributes, such that database persistence operations can be assembled based on those + changes. + + * The ability to organize individual SQL queries and population of newly generated + primary and foreign key-holding attributes during a persist operation such that + referential integrity is maintained at all times. + + * The ability to maintain insert ordering against the order in which new instances were + added to the session. + + * an Identity Map, which is a dictionary keying instances to their unique primary key + identity. This ensures that only one copy of a particular entity is ever present + within the session, even if repeated load operations for the same entity occur. This + allows many parts of an application to get a handle to a particular object without + any chance of modifications going to two different places. + + When dealing with instances of mapped classes, an instance may be *attached* to a + particular Session, else it is *unattached* . An instance also may or may not correspond + to an actual row in the database. These conditions break up into four distinct states: + + * *Transient* - an instance that's not in a session, and is not saved to the database; + i.e. it has no database identity. The only relationship such an object has to the ORM + is that its class has a `mapper()` associated with it. + + * *Pending* - when you `save()` a transient instance, it becomes pending. It still + wasn't actually flushed to the database yet, but it will be when the next flush + occurs. + + * *Persistent* - An instance which is present in the session and has a record in the + database. You get persistent instances by either flushing so that the pending + instances become persistent, or by querying the database for existing instances (or + moving persistent instances from other sessions into your local session). + + * *Detached* - an instance which has a record in the database, but is not in any + session. Theres nothing wrong with this, and you can use objects normally when + they're detached, **except** they will not be able to issue any SQL in order to load + collections or attributes which are not yet loaded, or were marked as "expired". + + The session methods which control instance state include ``save()``, ``update()``, + ``save_or_update()``, ``delete()``, ``merge()``, and ``expunge()``. + + The Session object is **not** threadsafe, particularly during flush operations. A session + which is only read from (i.e. is never flushed) can be used by concurrent threads if it's + acceptable that some object instances may be loaded twice. 
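As a rough illustration of the four states above, using the session methods listed (``User`` and its mapping are hypothetical)::

    sess = Session()

    user = User(name='ed')   # transient - no session, no database identity
    sess.save(user)          # pending - will be INSERTed at the next flush()
    sess.flush()             # persistent - in the session and in the database
    sess.expunge(user)       # detached - has a database identity, but no session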
+ + The typical pattern to managing Sessions in a multi-threaded environment is either to use + mutexes to limit concurrent access to one thread at a time, or more commonly to establish + a unique session for every thread, using a threadlocal variable. SQLAlchemy provides + a thread-managed Session adapter, provided by the [sqlalchemy.orm#scoped_session()] function. """ - def __init__(self, bind=None, autoflush=False, transactional=False, twophase=False, echo_uow=False, weak_identity_map=False): - self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map) + def __init__(self, bind=None, autoflush=True, transactional=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None): + """Construct a new Session. + + A session is usually constructed using the [sqlalchemy.orm#create_session()] function, + or its more "automated" variant [sqlalchemy.orm#sessionmaker()]. + + autoflush + When ``True``, all query operations will issue a ``flush()`` call to this + ``Session`` before proceeding. This is a convenience feature so that + ``flush()`` need not be called repeatedly in order for database queries to + retrieve results. It's typical that ``autoflush`` is used in conjunction with + ``transactional=True``, so that ``flush()`` is never called; you just call + ``commit()`` when changes are complete to finalize all changes to the + database. + + bind + An optional ``Engine`` or ``Connection`` to which this ``Session`` should be + bound. When specified, all SQL operations performed by this session will + execute via this connectable. + + binds + An optional dictionary, which contains more granular "bind" information than + the ``bind`` parameter provides. This dictionary can map individual ``Table`` + instances as well as ``Mapper`` instances to individual ``Engine`` or + ``Connection`` objects. Operations which proceed relative to a particular + ``Mapper`` will consult this dictionary for the direct ``Mapper`` instance as + well as the mapper's ``mapped_table`` attribute in order to locate an + connectable to use. The full resolution is described in the ``get_bind()`` + method of ``Session``. Usage looks like:: + + sess = Session(binds={ + SomeMappedClass : create_engine('postgres://engine1'), + somemapper : create_engine('postgres://engine2'), + some_table : create_engine('postgres://engine3'), + }) + + Also see the ``bind_mapper()`` and ``bind_table()`` methods. + + echo_uow + When ``True``, configure Python logging to dump all unit-of-work + transactions. This is the equivalent of + ``logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)``. + + extension + An optional [sqlalchemy.orm.session#SessionExtension] instance, which will receive + pre- and post- commit and flush events, as well as a post-rollback event. User- + defined code may be placed within these hooks using a user-defined subclass + of ``SessionExtension``. + + transactional + Set up this ``Session`` to automatically begin transactions. Setting this + flag to ``True`` is the rough equivalent of calling ``begin()`` after each + ``commit()`` operation, after each ``rollback()``, and after each + ``close()``. Basically, this has the effect that all session operations are + performed within the context of a transaction. Note that the ``begin()`` + operation does not immediately utilize any connection resources; only when + connection resources are first required do they get allocated into a + transactional context. 
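A short sketch of the ``transactional`` flag just described (the engine URL and the mapped instance are hypothetical)::

    from sqlalchemy import create_engine
    from sqlalchemy.orm.session import sessionmaker

    Session = sessionmaker(bind=create_engine('sqlite://'), transactional=True)

    sess = Session()     # a transaction is begun implicitly, using no connection yet
    sess.save(obj)       # 'obj' assumed to be a mapped instance
    sess.flush()         # SQL runs inside the session's transaction
    sess.commit()        # commits; a new (lazy) transaction begins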
+ + twophase + When ``True``, all transactions will be started using + [sqlalchemy.engine#TwoPhaseTransaction]. During a ``commit()``, after + ``flush()`` has been issued for all attached databases, the ``prepare()`` + method on each database's ``TwoPhaseTransaction`` will be called. This allows + each database to roll back the entire transaction, before each transaction is + committed. + + weak_identity_map + When set to the default value of ``True``, a weak-referencing map is used; + instances which are not externally referenced will be garbage collected + immediately. For dereferenced instances which have pending changes present, + the attribute management system will create a temporary strong-reference to + the object which lasts until the changes are flushed to the database, at which + point it's again dereferenced. Alternatively, when using the value ``False``, + the identity map uses a regular Python dictionary to store instances. The + session will maintain all instances present until they are removed using + expunge(), clear(), or purge(). + """ + self.echo_uow = echo_uow + self.weak_identity_map = weak_identity_map + self.uow = unitofwork.UnitOfWork(self) + self.identity_map = self.uow.identity_map self.bind = bind self.__binds = {} - self.echo_uow = echo_uow - self.weak_identity_map = weak_identity_map self.transaction = None self.hash_key = id(self) self.autoflush = autoflush - self.transactional = transactional or autoflush + self.transactional = transactional self.twophase = twophase + self.extension = extension self._query_cls = query.Query self._mapper_flush_opts = {} + + if binds is not None: + for mapperortable, value in binds.iteritems(): + if isinstance(mapperortable, type): + mapperortable = _class_mapper(mapperortable).base_mapper + self.__binds[mapperortable] = value + if isinstance(mapperortable, Mapper): + for t in mapperortable._all_tables: + self.__binds[t] = value + if self.transactional: self.begin() _sessions[self.hash_key] = self - - def _get_echo_uow(self): - return self.uow.echo - def _set_echo_uow(self, value): - self.uow.echo = value - echo_uow = property(_get_echo_uow,_set_echo_uow) - def begin(self, **kwargs): """Begin a transaction on this Session.""" @@ -166,75 +500,162 @@ class Session(object): else: self.transaction = SessionTransaction(self, **kwargs) return self.transaction - + create_transaction = begin def begin_nested(self): + """Begin a `nested` transaction on this Session. + + This utilizes a ``SAVEPOINT`` transaction for databases + which support this feature. + """ + return self.begin(nested=True) - + def rollback(self): + """Rollback the current transaction in progress. + + If no transaction is in progress, this method is a + pass-thru. + """ + if self.transaction is None: - raise exceptions.InvalidRequestError("No transaction is begun.") + pass else: - self.transaction = self.transaction.rollback() + self.transaction.rollback() + # TODO: we can rollback attribute values. however + # we would want to expand attributes.py to be able to save *two* rollback points, one to the + # last flush() and the other to when the object first entered the transaction. + # [ticket:705] + #attributes.rollback(*self.identity_map.values()) if self.transaction is None and self.transactional: self.begin() - + def commit(self): + """Commit the current transaction in progress. + + If no transaction is in progress, this method raises + an InvalidRequestError.
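For the ``begin_nested()`` behaviour described above, a rough sketch (``a`` and ``b`` are hypothetical mapped instances; note that, per the TODO above, in-memory attribute state is not rolled back in this version)::

    sess = Session()        # a transactional Session
    sess.save(a)
    sess.flush()            # 'a' is INSERTed in the outer transaction
    sess.begin_nested()     # emits a SAVEPOINT where the database supports it
    sess.save(b)
    sess.flush()            # 'b' is INSERTed inside the savepoint
    sess.rollback()         # rolls back to the savepoint only
    sess.commit()           # commits the outer transaction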
+ + If the ``begin()`` method was called on this ``Session`` + additional times subsequent to its first call, + ``commit()`` will not actually commit, and instead + pops an internal SessionTransaction off its internal stack + of transactions. Only when the "root" SessionTransaction + is reached does an actual database-level commit occur. + """ + if self.transaction is None: - raise exceptions.InvalidRequestError("No transaction is begun.") - else: - self.transaction = self.transaction.commit() + if self.transactional: + self.begin() + else: + raise exceptions.InvalidRequestError("No transaction is begun.") + + self.transaction.commit() if self.transaction is None and self.transactional: self.begin() - def connection(self, mapper=None, **kwargs): + def prepare(self): + """Prepare the current transaction in progress for two phase commit. + + If no transaction is in progress, this method raises + an InvalidRequestError. + + Only root transactions of two phase sessions can be prepared. If the current transaction is + not such, an InvalidRequestError is raised. + """ + if self.transaction is None: + if self.transactional: + self.begin() + else: + raise exceptions.InvalidRequestError("No transaction is begun.") + + self.transaction.prepare() + + def connection(self, mapper=None, clause=None, instance=None): """Return a ``Connection`` corresponding to this session's transactional context, if any. If this ``Session`` is transactional, the connection will be in the context of this session's transaction. Otherwise, the - connection is returned by the ``contextual_connect()`` method, which - some Engines override to return a thread-local connection, and - will have `close_with_result` set to `True`. + connection is returned by the ``contextual_connect()`` method + on the engine. - The given `**kwargs` will be sent to the engine's - ``contextual_connect()`` method, if no transaction is in - progress. - - the "mapper" argument is a class or mapper to which a bound engine - will be located; use this when the Session itself is unbound. + The `mapper` argument is a class or mapper to which a bound engine + will be located; use this when the Session itself is either bound + to multiple engines or connections, or is not bound to any connectable. + + \**kwargs are additional arguments which will be passed to get_bind(). + See the get_bind() method for details. Note that the ``ShardedSession`` + subclass takes a different get_bind() argument signature. """ + return self.__connection(self.get_bind(mapper, clause, instance)) + + def __connection(self, engine, **kwargs): if self.transaction is not None: - return self.transaction.connection(mapper) + return self.transaction.get_or_add(engine) else: - return self.get_bind(mapper).contextual_connect(**kwargs) - - def execute(self, clause, params=None, mapper=None, **kwargs): - """Using the given mapper to identify the appropriate ``Engine`` - or ``Connection`` to be used for statement execution, execute the - given ``ClauseElement`` using the provided parameter dictionary. + return engine.contextual_connect(**kwargs) - Return a ``ResultProxy`` corresponding to the execution's results. + def execute(self, clause, params=None, mapper=None, instance=None): + """Execute the given clause, using the current transaction (if any). - If this method allocates a new ``Connection`` for the operation, - then the ``ResultProxy`` 's ``close()`` method will release the - resources of the underlying ``Connection``, otherwise its a no-op. 
+ Returns a ``ResultProxy`` corresponding to the execution's results. + + clause + a ClauseElement (i.e. select(), text(), etc.) or + string SQL statement to be executed + + params + a dictionary of bind parameters. + + mapper + a mapped class or Mapper instance which may be needed + in order to locate the proper bind. This is typically + if the Session is not directly bound to a single engine. + + instance + used by some Query operations to further identify + the proper bind, in the case of ShardedSession. + """ - return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs) + engine = self.get_bind(mapper, clause=clause, instance=instance) - def scalar(self, clause, params=None, mapper=None, **kwargs): + return self.__connection(engine, close_with_result=True).execute(clause, params or {}) + + def scalar(self, clause, params=None, mapper=None, instance=None): """Like execute() but return a scalar result.""" - return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs) + engine = self.get_bind(mapper, clause=clause, instance=instance) + + return self.__connection(engine, close_with_result=True).scalar(clause, params or {}) def close(self): - """Close this Session.""" + """Close this Session. + + This clears all items and ends any transaction in progress. + + If this session were created with ``transactional=True``, a + new transaction is immediately begun. Note that this new + transaction does not use any connection resources until they + are first needed. + """ self.clear() if self.transaction is not None: - self.transaction.close() + for transaction in self.transaction._iterate_parents(): + transaction.close() + if self.transactional: + # note this doesnt use any connection resources + self.begin() + + def close_all(cls): + """Close *all* sessions in memory.""" + + for sess in _sessions.values(): + sess.close() + close_all = classmethod(close_all) def clear(self): """Remove all object instances from this ``Session``. @@ -242,18 +663,13 @@ class Session(object): This is equivalent to calling ``expunge()`` for all objects in this ``Session``. """ - + for instance in self: self._unattach(instance) - echo = self.uow.echo - self.uow = unitofwork.UnitOfWork(weak_identity_map=self.weak_identity_map) - self.uow.echo = echo + self.uow = unitofwork.UnitOfWork(self) + self.identity_map = self.uow.identity_map - def mapper(self, class_, entity_name=None): - """Given a ``Class``, return the primary ``Mapper`` responsible for - persisting it.""" - - return _class_mapper(class_, entity_name = entity_name) + # TODO: need much more test coverage for bind_mapper() and similar ! def bind_mapper(self, mapper, bind, entity_name=None): """Bind the given `mapper` or `class` to the given ``Engine`` or ``Connection``. @@ -261,11 +677,13 @@ class Session(object): All subsequent operations involving this ``Mapper`` will use the given `bind`. """ - + if isinstance(mapper, type): mapper = _class_mapper(mapper, entity_name=entity_name) - self.__binds[mapper] = bind + self.__binds[mapper.base_mapper] = bind + for t in mapper._all_tables: + self.__binds[t] = bind def bind_table(self, table, bind): """Bind the given `table` to the given ``Engine`` or ``Connection``. @@ -276,83 +694,95 @@ class Session(object): self.__binds[table] = bind - def get_bind(self, mapper): - """Return the ``Engine`` or ``Connection`` which is used to execute - statements on behalf of the given `mapper`. 
+ def get_bind(self, mapper, clause=None, instance=None): + """Return an engine corresponding to the given arguments. - Calling ``connect()`` on the return result will always result - in a ``Connection`` object. This method disregards any - ``SessionTransaction`` that may be in progress. + mapper + mapper relative to the desired operation. - The order of searching is as follows: - - 1. if an ``Engine`` or ``Connection`` was bound to this ``Mapper`` - specifically within this ``Session``, return that ``Engine`` or - ``Connection``. - - 2. if an ``Engine`` or ``Connection`` was bound to this `mapper` 's - underlying ``Table`` within this ``Session`` (i.e. not to the ``Table`` - directly), return that ``Engine`` or ``Connection``. - - 3. if an ``Engine`` or ``Connection`` was bound to this ``Session``, - return that ``Engine`` or ``Connection``. - - 4. finally, return the ``Engine`` which was bound directly to the - ``Table`` 's ``MetaData`` object. - - If no ``Engine`` is bound to the ``Table``, an exception is raised. + clause + a ClauseElement which is to be executed. if + mapper is not present, this may be used to locate + Table objects, which are then associated with mappers + which have associated binds. + + instance + an ORM mapped instance which may be used to further + locate the correct bind. This is currently used by + the ShardedSession subclass. + """ - - if mapper is None: + if mapper is None and clause is None: if self.bind is not None: return self.bind else: - raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()") - elif self.__binds.has_key(mapper): - return self.__binds[mapper] - elif self.__binds.has_key(mapper.mapped_table): - return self.__binds[mapper.mapped_table] - elif self.bind is not None: + raise exceptions.UnboundExecutionError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()") + + elif len(self.__binds): + if mapper is not None: + if isinstance(mapper, type): + mapper = _class_mapper(mapper) + if mapper.base_mapper in self.__binds: + return self.__binds[mapper.base_mapper] + elif mapper.compile().mapped_table in self.__binds: + return self.__binds[mapper.mapped_table] + if clause is not None: + for t in clause._table_iterator(): + if t in self.__binds: + return self.__binds[t] + + if self.bind is not None: return self.bind + elif isinstance(clause, sql.expression.ClauseElement) and clause.bind is not None: + return clause.bind + elif mapper is None: + raise exceptions.UnboundExecutionError("Could not locate any mapper associated with SQL expression") else: + if isinstance(mapper, type): + mapper = _class_mapper(mapper) + else: + mapper = mapper.compile() e = mapper.mapped_table.bind if e is None: - raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper)) + raise exceptions.UnboundExecutionError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper)) return e def query(self, mapper_or_class, *addtl_entities, **kwargs): """Return a new ``Query`` object corresponding to this ``Session`` and the mapper, or the classes' primary mapper. 
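To make the ``get_bind()`` lookup order above concrete, a hedged sketch (the engines, the ``User`` mapping and ``accounts_table`` are hypothetical)::

    from sqlalchemy import create_engine
    from sqlalchemy.orm.session import Session

    engine1 = create_engine('postgres://user@host1/db')
    engine2 = create_engine('postgres://user@host2/db')

    sess = Session(binds={User: engine1, accounts_table: engine2})

    sess.get_bind(User)                      # engine1, located via User's base mapper
    sess.execute(accounts_table.select())    # engine2, located via the clause's table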
+ """ - entity_name = kwargs.pop('entity_name', None) - + if isinstance(mapper_or_class, type): q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs) else: q = self._query_cls(mapper_or_class, self, **kwargs) - + for ent in addtl_entities: q = q.add_entity(ent) return q - def _sql(self): - class SQLProxy(object): - def __getattr__(self, key): - def call(*args, **kwargs): - kwargs[engine] = self.engine - return getattr(sql, key)(*args, **kwargs) - - sql = property(_sql) + def _autoflush(self): + if self.autoflush and (self.transaction is None or self.transaction.autoflush): + self.flush() def flush(self, objects=None): """Flush all the object modifications present in this session to the database. - `objects` is a list or tuple of objects specifically to be + `objects` is a collection or iterator of objects specifically to be flushed; if ``None``, all new and modified objects are flushed. - """ + """ + if objects is not None: + try: + if not len(objects): + return + except TypeError: + objects = list(objects) + if not objects: + return self.uow.flush(self, objects) def get(self, class_, ident, **kwargs): @@ -389,58 +819,90 @@ class Session(object): entity_name = kwargs.pop('entity_name', None) return self.query(class_, entity_name=entity_name).load(ident, **kwargs) - def refresh(self, obj): - """Reload the attributes for the given object from the - database, clear any changes made. - """ + def refresh(self, instance, attribute_names=None): + """Refresh the attributes on the given instance. + + When called, a query will be issued + to the database which will refresh all attributes with their + current value. - self._validate_persistent(obj) - if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None: - raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % repr(obj)) + Lazy-loaded relational attributes will remain lazily loaded, so that + the instance-wide refresh operation will be followed + immediately by the lazy load of that attribute. - def expire(self, obj): - """Mark the given object as expired. + Eagerly-loaded relational attributes will eagerly load within the + single refresh operation. - This will add an instrumentation to all mapped attributes on - the instance such that when an attribute is next accessed, the - session will reload all attributes on the instance from the - database. + The ``attribute_names`` argument is an iterable collection + of attribute names indicating a subset of attributes to be + refreshed. """ - for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)): - self._expire_impl(c) + self._validate_persistent(instance) - def _expire_impl(self, obj): - self._validate_persistent(obj) + if self.query(_object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None: + raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) + + def expire_all(self): + """Expires all persistent instances within this Session. + + """ + for state in self.identity_map.all_states(): + _expire_state(state, None) + + def expire(self, instance, attribute_names=None): + """Expire the attributes on the given instance. 
- def exp(): - if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None: - raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % repr(obj)) + The instance's attributes are instrumented such that + when an attribute is next accessed, a query will be issued + to the database which will refresh all attributes with their + current value. - attribute_manager.trigger_history(obj, exp) + The ``attribute_names`` argument is an iterable collection + of attribute names indicating a subset of attributes to be + expired. + """ - def is_expired(self, obj, unexpire=False): - """Return True if the given object has been marked as expired.""" + if attribute_names: + self._validate_persistent(instance) + _expire_state(instance._state, attribute_names=attribute_names) + else: + # pre-fetch the full cascade since the expire is going to + # remove associations + cascaded = list(_cascade_iterator('refresh-expire', instance)) + self._validate_persistent(instance) + _expire_state(instance._state, None) + for (c, m) in cascaded: + self._validate_persistent(c) + _expire_state(c._state, None) + + def prune(self): + """Remove unreferenced instances cached in the identity map. + + Note that this method is only meaningful if "weak_identity_map" + is set to False. + + Removes any object in this Session's identity map that is not + referenced in user code, modified, new or scheduled for deletion. + Returns the number of objects pruned. + """ - ret = attribute_manager.has_trigger(obj) - if ret and unexpire: - attribute_manager.untrigger_history(obj) - return ret + return self.uow.prune_identity_map() - def expunge(self, object): - """Remove the given `object` from this ``Session``. + def expunge(self, instance): + """Remove the given `instance` from this ``Session``. - This will free all internal references to the object. + This will free all internal references to the instance. Cascading will be applied according to the *expunge* cascade rule. """ - self._validate_persistent(object) - for c in [object] + list(_object_mapper(object).cascade_iterator('expunge', object)): + self._validate_persistent(instance) + for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)): if c in self: - self.uow._remove_deleted(c) + self.uow._remove_deleted(c._state) self._unattach(c) - def save(self, object, entity_name=None): + def save(self, instance, entity_name=None): """Add a transient (unsaved) instance to this ``Session``. This operation cascades the `save_or_update` method to @@ -450,60 +912,53 @@ class Session(object): The `entity_name` keyword argument will further qualify the specific ``Mapper`` used to handle this instance. """ + self._save_impl(instance, entity_name=entity_name) + self._cascade_save_or_update(instance) - self._save_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, - lambda c, e:self._save_or_update_impl(c, e), - halt_on=lambda c:c in self) - - def update(self, object, entity_name=None): + def update(self, instance, entity_name=None): """Bring the given detached (saved) instance into this ``Session``. - If there is a persistent instance with the same identifier - already associated with this ``Session``, an exception is thrown. + If there is a persistent instance with the same instance key, but + different identity already associated with this ``Session``, an + InvalidRequestError exception is thrown. 
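A quick sketch of the ``expire()`` / ``refresh()`` behaviour described above (``user`` is a hypothetical persistent instance)::

    sess.expire(user, ['name'])   # only 'name' is expired
    print user.name               # the next access issues a SELECT for the expired attribute

    sess.expire(user)             # expire every attribute
    sess.refresh(user)            # or re-SELECT all attributes immediately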
This operation cascades the `save_or_update` method to associated instances if the relation is mapped with ``cascade="save-update"``. """ - self._update_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, - lambda c, e:self._save_or_update_impl(c, e), - halt_on=lambda c:c in self) + self._update_impl(instance, entity_name=entity_name) + self._cascade_save_or_update(instance) - def save_or_update(self, object, entity_name=None): - """Save or update the given object into this ``Session``. + def save_or_update(self, instance, entity_name=None): + """Save or update the given instance into this ``Session``. The presence of an `_instance_key` attribute on the instance determines whether to ``save()`` or ``update()`` the instance. """ - self._save_or_update_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, - lambda c, e:self._save_or_update_impl(c, e), - halt_on=lambda c:c in self) + self._save_or_update_impl(instance, entity_name=entity_name) + self._cascade_save_or_update(instance) - def _save_or_update_impl(self, object, entity_name=None): - key = getattr(object, '_instance_key', None) - if key is None: - self._save_impl(object, entity_name=entity_name) - else: - self._update_impl(object, entity_name=entity_name) + def _cascade_save_or_update(self, instance): + for obj, mapper in _cascade_iterator('save-update', instance, halt_on=lambda c:c in self): + self._save_or_update_impl(obj, mapper.entity_name) - def delete(self, object): + def delete(self, instance): """Mark the given instance as deleted. The delete operation occurs upon ``flush()``. """ - for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)): - self.uow.register_deleted(c) + self._delete_impl(instance) + for c, m in _cascade_iterator('delete', instance): + self._delete_impl(c, ignore_transient=True) - def merge(self, object, entity_name=None, _recursive=None): - """Copy the state of the given `object` onto the persistent - object with the same identifier. + + def merge(self, instance, entity_name=None, dont_load=False, _recursive=None): + """Copy the state of the given `instance` onto the persistent + instance with the same identifier. If there is no persistent instance currently associated with the session, it will be loaded. Return the persistent @@ -516,41 +971,57 @@ class Session(object): """ if _recursive is None: - _recursive = util.Set() + _recursive = {} # TODO: this should be an IdentityDict for instances, but will need a separate + # dict for PropertyLoader tuples if entity_name is not None: - mapper = _class_mapper(object.__class__, entity_name=entity_name) + mapper = _class_mapper(instance.__class__, entity_name=entity_name) else: - mapper = _object_mapper(object) - if mapper in _recursive or object in _recursive: - return None - _recursive.add(mapper) - _recursive.add(object) - try: - key = getattr(object, '_instance_key', None) - if key is None: - merged = mapper._create_instance(self) + mapper = _object_mapper(instance) + if instance in _recursive: + return _recursive[instance] + + key = getattr(instance, '_instance_key', None) + if key is None: + if dont_load: + raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects. 
flush() all changes on mapped instances before merging with dont_load=True.") + key = mapper.identity_key_from_instance(instance) + + merged = None + if key: + if key in self.identity_map: + merged = self.identity_map[key] + elif dont_load: + if instance._state.modified: + raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True.") + + merged = attributes.new_instance(mapper.class_) + merged._instance_key = key + merged._entity_name = entity_name + self._update_impl(merged, entity_name=mapper.entity_name) else: - if key in self.identity_map: - merged = self.identity_map[key] - else: - merged = self.get(mapper.class_, key[1]) - if merged is None: - raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object)) - for prop in mapper.iterate_properties: - prop.merge(self, object, merged, _recursive) - if key is None: - self.save(merged, entity_name=mapper.entity_name) - return merged - finally: - _recursive.remove(mapper) + merged = self.get(mapper.class_, key[1]) + + if merged is None: + merged = attributes.new_instance(mapper.class_) + self.save(merged, entity_name=mapper.entity_name) - def identity_key(self, *args, **kwargs): + _recursive[instance] = merged + + for prop in mapper.iterate_properties: + prop.merge(self, instance, merged, dont_load, _recursive) + + if dont_load: + merged._state.commit_all() # remove any history + + return merged + + def identity_key(cls, *args, **kwargs): """Get an identity key. Valid call signatures: * ``identity_key(class, ident, entity_name=None)`` - + class mapped class (must be a positional argument) @@ -561,12 +1032,12 @@ class Session(object): optional entity name * ``identity_key(instance=instance)`` - + instance object instance (must be given as a keyword arg) * ``identity_key(class, row=row, entity_name=None)`` - + class mapped class (must be a positional argument) @@ -606,133 +1077,178 @@ class Session(object): % ", ".join(kwargs.keys())) mapper = _object_mapper(instance) return mapper.identity_key_from_instance(instance) + identity_key = classmethod(identity_key) - def _save_impl(self, object, **kwargs): - if hasattr(object, '_instance_key'): - if not self.identity_map.has_key(object._instance_key): - raise exceptions.InvalidRequestError("Instance '%s' is a detached instance " - "or is already persistent in a " - "different Session" % repr(object)) - else: - m = _class_mapper(object.__class__, entity_name=kwargs.get('entity_name', None)) + def object_session(cls, instance): + """Return the ``Session`` to which the given object belongs.""" - # this would be a nice exception to raise...however this is incompatible with a contextual - # session which puts all objects into the session upon construction. 
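The ``dont_load`` flag on ``merge()`` above is aimed at cache-style usage; a hedged sketch (the cache itself is hypothetical)::

    # 'cached' is a detached, fully flushed instance held outside any Session
    cached = cache[key]

    # copy its state into this Session without issuing a SELECT; the instance
    # must carry an _instance_key and have no pending changes
    local = sess.merge(cached, dont_load=True)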
- #if m._is_orphan(object): - # raise exceptions.InvalidRequestError("Instance '%s' is an orphan, " - # "and must be attached to a parent " - # "object to be saved" % (repr(object))) + return object_session(instance) + object_session = classmethod(object_session) - m._assign_entity_name(object) - self._register_pending(object) - - def _update_impl(self, object, **kwargs): - if self._is_attached(object) and object not in self.deleted: + def _save_impl(self, instance, **kwargs): + if hasattr(instance, '_instance_key'): + raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(instance)) + else: + # TODO: consolidate the steps here + attributes.manage(instance) + instance._entity_name = kwargs.get('entity_name', None) + self._attach(instance) + self.uow.register_new(instance) + + def _update_impl(self, instance, **kwargs): + if instance in self and instance not in self.deleted: return - if not hasattr(object, '_instance_key'): - raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % repr(object)) - self._attach(object) - - def _register_pending(self, obj): - self._attach(obj) - self.uow.register_new(obj) - - def _register_persistent(self, obj): - self._attach(obj) - self.uow.register_clean(obj) - - def _register_deleted(self, obj): - self._attach(obj) - self.uow.register_deleted(obj) - - def _attach(self, obj): - """Attach the given object to this ``Session``.""" + if not hasattr(instance, '_instance_key'): + raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance)) + elif self.identity_map.get(instance._instance_key, instance) is not instance: + raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(instance), instance._instance_key)) + self._attach(instance) + + def _save_or_update_impl(self, instance, entity_name=None): + key = getattr(instance, '_instance_key', None) + if key is None: + self._save_impl(instance, entity_name=entity_name) + else: + self._update_impl(instance, entity_name=entity_name) - old_id = getattr(obj, '_sa_session_id', None) + def _delete_impl(self, instance, ignore_transient=False): + if instance in self and instance in self.deleted: + return + if not hasattr(instance, '_instance_key'): + if ignore_transient: + return + else: + raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance)) + if self.identity_map.get(instance._instance_key, instance) is not instance: + raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(instance), instance._instance_key)) + self._attach(instance) + self.uow.register_deleted(instance) + + def _attach(self, instance): + old_id = getattr(instance, '_sa_session_id', None) if old_id != self.hash_key: - if old_id is not None and _sessions.has_key(old_id): + if old_id is not None and old_id in _sessions and instance in _sessions[old_id]: raise exceptions.InvalidRequestError("Object '%s' is already attached " "to session '%s' (this is '%s')" % - (repr(obj), old_id, id(self))) - - # auto-removal from the old session is disabled. 
but if we decide to - # turn it back on, do it as below: gingerly since _sessions is a WeakValueDict - # and it might be affected by other threads - #try: - # sess = _sessions[old] - #except KeyError: - # sess = None - #if sess is not None: - # sess.expunge(old) - key = getattr(obj, '_instance_key', None) + (mapperutil.instance_str(instance), old_id, id(self))) + + key = getattr(instance, '_instance_key', None) if key is not None: - self.identity_map[key] = obj - obj._sa_session_id = self.hash_key + self.identity_map[key] = instance + instance._sa_session_id = self.hash_key - def _unattach(self, obj): - if not self._is_attached(obj): - raise exceptions.InvalidRequestError("Instance '%s' not attached to this Session" % repr(obj)) - del obj._sa_session_id + def _unattach(self, instance): + if instance._sa_session_id == self.hash_key: + del instance._sa_session_id - def _validate_persistent(self, obj): - """Validate that the given object is persistent within this + def _validate_persistent(self, instance): + """Validate that the given instance is persistent within this ``Session``. """ - self.uow._validate_obj(obj) + if instance not in self: + raise exceptions.InvalidRequestError("Instance '%s' is not persistent within this Session" % mapperutil.instance_str(instance)) + + def __contains__(self, instance): + """Return True if the given instance is associated with this session. - def _is_attached(self, obj): - return getattr(obj, '_sa_session_id', None) == self.hash_key + The instance may be pending or persistent within the Session for a + result of True. + """ - def __contains__(self, obj): - return self._is_attached(obj) and (obj in self.uow.new or self.identity_map.has_key(obj._instance_key)) + return instance._state in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance) def __iter__(self): - return iter(list(self.uow.new) + self.uow.identity_map.values()) + """Return an iterator of all instances which are pending or persistent within this Session.""" + + return iter(list(self.uow.new.values()) + self.uow.identity_map.values()) + + def is_modified(self, instance, include_collections=True, passive=False): + """Return True if the given instance has modified attributes. + + This method retrieves a history instance for each instrumented attribute + on the instance and performs a comparison of the current value to its + previously committed value. Note that instances present in the 'dirty' + collection may result in a value of ``False`` when tested with this method. + + `include_collections` indicates if multivalued collections should be included + in the operation. Setting this to False is a way to detect only local-column + based properties (i.e. scalar columns or many-to-one foreign keys) that would + result in an UPDATE for this instance upon flush. - def _get(self, key): - return self.identity_map[key] + The `passive` flag indicates if unloaded attributes and collections should + not be loaded in the course of performing this test. + """ - def has_key(self, key): - return self.identity_map.has_key(key) + for attr in attributes._managed_attributes(instance.__class__): + if not include_collections and hasattr(attr.impl, 'get_collection'): + continue + (added, unchanged, deleted) = attr.get_history(instance) + if added or deleted: + return True + return False + + def dirty(self): + """Return a ``Set`` of all instances marked as 'dirty' within this ``Session``. 
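To illustrate the distinction ``is_modified()`` draws above between "optimistically dirty" and an actual net change (``user`` is a hypothetical persistent instance)::

    user.name = user.name         # an attribute set with no net change
    print user in sess.dirty      # True - the set() marked the instance dirty
    print sess.is_modified(user)  # False - old and new values compare equal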
+ + Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection + modification operations will mark an instance as 'dirty' and place it in this set, + even if there is no net change to the attribute's value. At flush time, the value + of each attribute is compared to its previously saved value, + and if there's no net change, no SQL operation will occur (this is a more expensive + operation so it's only done at flush time). + + To check if an instance has actionable net changes to its attributes, use the + is_modified() method. + """ - dirty = property(lambda s:s.uow.locate_dirty(), - doc="A ``Set`` of all objects marked as 'dirty' within this ``Session``") + return self.uow.locate_dirty() + dirty = property(dirty) - deleted = property(lambda s:s.uow.deleted, - doc="A ``Set`` of all objects marked as 'deleted' within this ``Session``") + def deleted(self): + "Return a ``Set`` of all instances marked as 'deleted' within this ``Session``" + + return util.IdentitySet(self.uow.deleted.values()) + deleted = property(deleted) - new = property(lambda s:s.uow.new, - doc="A ``Set`` of all objects marked as 'new' within this ``Session``.") + def new(self): + "Return a ``Set`` of all instances marked as 'new' within this ``Session``." + + return util.IdentitySet(self.uow.new.values()) + new = property(new) - identity_map = property(lambda s:s.uow.identity_map, - doc="A dictionary consisting of all objects " - "within this ``Session`` keyed to their `_instance_key` value.") +def _expire_state(state, attribute_names): + """Standalone expire instance function. - def import_instance(self, *args, **kwargs): - """Deprecated. A synynom for ``merge()``.""" + Installs a callable with the given instance's _state + which will fire off when any of the named attributes are accessed; + their existing value is removed. - return self.merge(*args, **kwargs) + If the list is None or blank, the entire instance is expired. + """ + + state.expire_attributes(attribute_names) -# this is the AttributeManager instance used to provide attribute behavior on objects. -# to all the "global variable police" out there: its a stateless object. -attribute_manager = unitofwork.attribute_manager +register_attribute = unitofwork.register_attribute -# this dictionary maps the hash key of a Session to the Session itself, and -# acts as a Registry with which to locate Sessions. this is to enable -# object instances to be associated with Sessions without having to attach the -# actual Session object directly to the object instance. 
_sessions = weakref.WeakValueDictionary() -def object_session(obj): - """Return the ``Session`` to which the given object is bound, or ``None`` if none.""" +def _cascade_iterator(cascade, instance, **kwargs): + mapper = _object_mapper(instance) + for (o, m) in mapper.cascade_iterator(cascade, instance._state, **kwargs): + yield o, m + +def object_session(instance): + """Return the ``Session`` to which the given instance is bound, or ``None`` if none.""" - hashkey = getattr(obj, '_sa_session_id', None) + hashkey = getattr(instance, '_sa_session_id', None) if hashkey is not None: - return _sessions.get(hashkey) + sess = _sessions.get(hashkey) + if sess is not None and instance in sess: + return sess return None # Lazy initialization to avoid circular imports unitofwork.object_session = object_session from sqlalchemy.orm import mapper -mapper.attribute_manager = attribute_manager +mapper._expire_state = _expire_state diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py index cc13f8c1fe..7cf4eb2cc5 100644 --- a/lib/sqlalchemy/orm/shard.py +++ b/lib/sqlalchemy/orm/shard.py @@ -1,21 +1,32 @@ +"""Defines a rudimental 'horizontal sharding' system which allows a +Session to distribute queries and persistence operations across multiple +databases. + +For a usage example, see the example ``examples/sharding/attribute_shard.py``. + +""" from sqlalchemy.orm.session import Session -from sqlalchemy.orm import Query +from sqlalchemy.orm.query import Query +from sqlalchemy import exceptions, util + +__all__ = ['ShardedSession', 'ShardedQuery'] class ShardedSession(Session): - def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs): + def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs): """construct a ShardedSession. shard_chooser - a callable which, passed a Mapper and a mapped instance, returns a - shard ID. this id may be based off of the attributes present within the - object, or on some round-robin scheme. If the scheme is based on a - selection, it should set whatever state on the instance to mark it in - the future as participating in that shard. + a callable which, passed a Mapper, a mapped instance, and possibly a + SQL clause, returns a shard ID. this id may be based off of the + attributes present within the object, or on some round-robin scheme. If + the scheme is based on a selection, it should set whatever state on the + instance to mark it in the future as participating in that shard. id_chooser - a callable, passed a tuple of identity values, which should return - a list of shard ids where the ID might reside. The databases will - be queried in the order of this listing. + a callable, passed a query and a tuple of identity values, + which should return a list of shard ids where the ID might + reside. The databases will be queried in the order of this + listing. 
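A rough end-to-end sketch of constructing a ``ShardedSession`` from the callables described here, together with the ``query_chooser`` documented just below (shard names, engine URLs and routing logic are all hypothetical)::

    from sqlalchemy import create_engine
    from sqlalchemy.orm.shard import ShardedSession

    shards = {'east': create_engine('postgres://user@east/db'),
              'west': create_engine('postgres://user@west/db')}

    def shard_chooser(mapper, instance, clause=None):
        # route instances by a (hypothetical) 'region' attribute
        if instance is not None and instance.region == 'west':
            return 'west'
        return 'east'

    def id_chooser(query, ident):
        # a given primary key could live on either shard
        return ['east', 'west']

    def query_chooser(query):
        # with no further information, search every shard
        return ['east', 'west']

    sess = ShardedSession(shard_chooser, id_chooser, query_chooser, shards=shards)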
query_chooser for a given Query, returns the list of shard_ids where the query @@ -30,6 +41,9 @@ class ShardedSession(Session): self.__binds = {} self._mapper_flush_opts = {'connection_callable':self.connection} self._query_cls = ShardedQuery + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): if shard_id is None: @@ -40,9 +54,9 @@ class ShardedSession(Session): else: return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs) - def get_bind(self, mapper, shard_id=None, instance=None): + def get_bind(self, mapper, shard_id=None, instance=None, clause=None): if shard_id is None: - shard_id = self.shard_chooser(mapper, instance) + shard_id = self.shard_chooser(mapper, instance, clause=clause) return self.__binds[shard_id] def bind_shard(self, shard_id, bind): @@ -71,19 +85,19 @@ class ShardedQuery(Query): q._shard_id = shard_id return q - def _execute_and_instances(self, statement): + def _execute_and_instances(self, context): if self._shard_id is not None: - result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(statement, **self._params) + result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(context.statement, **self._params) try: - return iter(self.instances(result)) + return iter(self.instances(result, querycontext=context)) finally: result.close() else: partial = [] for shard_id in self.query_chooser(self): - result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(statement, **self._params) + result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(context.statement, **self._params) try: - partial = partial + list(self.instances(result)) + partial = partial + list(self.instances(result, querycontext=context)) finally: result.close() # if some kind of in memory 'sorting' were done, this is where it would happen @@ -93,7 +107,8 @@ class ShardedQuery(Query): if self._shard_id is not None: return super(ShardedQuery, self).get(ident) else: - for shard_id in self.id_chooser(ident): + ident = util.to_list(ident) + for shard_id in self.id_chooser(self, ident): o = self.set_shard(shard_id).get(ident, **kwargs) if o is not None: return o @@ -104,7 +119,7 @@ class ShardedQuery(Query): if self._shard_id is not None: return super(ShardedQuery, self).load(ident) else: - for shard_id in self.id_chooser(ident): + for shard_id in self.id_chooser(self, ident): o = self.set_shard(shard_id).load(ident, raiseerr=False, **kwargs) if o is not None: return o diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index babd6e4c09..65a8b019b8 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1,19 +1,23 @@ # strategies.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" -from sqlalchemy import sql, util, exceptions, sql_util, logging +from sqlalchemy import sql, util, exceptions, logging +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import visitors, expression, operators from sqlalchemy.orm import mapper, attributes -from sqlalchemy.orm.interfaces import LoaderStrategy, 
StrategizedOption, MapperOption, PropertyOption +from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, serialize_path, deserialize_path from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil class ColumnLoader(LoaderStrategy): + """Default column loader.""" + def init(self): super(ColumnLoader, self).init() self.columns = self.parent_property.columns @@ -23,11 +27,12 @@ class ColumnLoader(LoaderStrategy): def setup_query(self, context, parentclauses=None, **kwargs): for c in self.columns: if parentclauses is not None: - context.statement.append_column(parentclauses.aliased_column(c)) + context.secondary_columns.append(parentclauses.aliased_column(c)) else: - context.statement.append_column(c) + context.primary_columns.append(c) def init_class_attribute(self): + self.is_class_level = True if self.is_composite: self._init_composite_attribute() else: @@ -36,19 +41,22 @@ class ColumnLoader(LoaderStrategy): def _init_composite_attribute(self): self.logger.info("register managed composite attribute %s on class %s" % (self.key, self.parent.class_.__name__)) def copy(obj): - return self.parent_property.composite_class(*obj.__colset__()) + return self.parent_property.composite_class( + *obj.__composite_values__()) def compare(a, b): - for col, aprop, bprop in zip(self.columns, a.__colset__(), b.__colset__()): + for col, aprop, bprop in zip(self.columns, + a.__composite_values__(), + b.__composite_values__()): if not col.type.compare_values(aprop, bprop): return False else: return True - sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator) + sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator) def _init_scalar_attribute(self): self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) coltype = self.columns[0].type - sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) + sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) def create_row_processor(self, selectcontext, mapper, row): if self.is_composite: @@ -56,91 +64,50 @@ class ColumnLoader(LoaderStrategy): if c not in row: break else: - def execute(instance, row, isnew, ispostselect=None, **flags): - if isnew or ispostselect: - if self._should_log_debug: - self.logger.debug("populating %s with %s/%s..." % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) - instance.__dict__[self.key] = self.parent_property.composite_class(*[row[c] for c in self.columns]) - self.logger.debug("Returning active composite column fetcher for %s %s" % (mapper, self.key)) - return (execute, None) + def new_execute(instance, row, **flags): + if self._should_log_debug: + self.logger.debug("populating %s with %s/%s..." 
% (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) + instance.__dict__[self.key] = self.parent_property.composite_class(*[row[c] for c in self.columns]) + if self._should_log_debug: + self.logger.debug("Returning active composite column fetcher for %s %s" % (mapper, self.key)) + return (new_execute, None, None) elif self.columns[0] in row: - def execute(instance, row, isnew, ispostselect=None, **flags): - if isnew or ispostselect: - if self._should_log_debug: - self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) - instance.__dict__[self.key] = row[self.columns[0]] - self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key)) - return (execute, None) - - (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None)) - if hosted_mapper is None: - return (None, None) - - if hosted_mapper.polymorphic_fetch == 'deferred': - def execute(instance, row, isnew, **flags): + def new_execute(instance, row, **flags): + if self._should_log_debug: + self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) + instance.__dict__[self.key] = row[self.columns[0]] + if self._should_log_debug: + self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key)) + return (new_execute, None, None) + else: + def new_execute(instance, row, isnew, **flags): if isnew: - sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_loader(instance, mapper, needs_tables)) - self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key)) - return (execute, None) - else: - self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key)) - return (None, None) - - def _get_deferred_loader(self, instance, mapper, needs_tables): - def load(): - group = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables] - + instance._state.expire_attributes([self.key]) if self._should_log_debug: - self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None')) - - session = sessionlib.object_session(instance) - if session is None: - raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) - - cond, param_names = mapper._deferred_inheritance_condition(needs_tables) - statement = sql.select(needs_tables, cond, use_labels=True) - params = {} - for c in param_names: - params[c.name] = mapper.get_attr_by_column(instance, c) - - result = session.execute(statement, params, mapper=mapper) - try: - row = result.fetchone() - for prop in group: - sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) - return attributes.ATTR_WAS_SET - finally: - result.close() - - return load + self.logger.debug("Deferring load for %s %s" % (mapper, self.key)) + return (new_execute, None, None) ColumnLoader.logger = logging.class_logger(ColumnLoader) class DeferredColumnLoader(LoaderStrategy): - """Describes an object attribute that corresponds to a table - column, which also will *lazy load* its value from the table. - - This is per-column lazy loading. 
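For orientation, the strategy being reworked here is what services deferred column access at the user level. A rough, illustrative sketch against the public 0.4 API follows; the Book/books names and the 'content' group are made up for the example:

    from sqlalchemy import MetaData, Table, Column, Integer, String, Text, create_engine
    from sqlalchemy.orm import mapper, deferred, undefer, create_session

    metadata = MetaData()
    books = Table('books', metadata,
        Column('id', Integer, primary_key=True),
        Column('title', String(100)),
        Column('summary', Text))

    class Book(object):
        pass

    # 'summary' is deferred: omitted from the main SELECT and loaded
    # per-column (or per-group) only when the attribute is first accessed
    mapper(Book, books, properties={
        'summary': deferred(books.c.summary, group='content')})

    engine = create_engine('sqlite://')
    metadata.create_all(engine)
    session = create_session(bind=engine)

    # undefer() switches the strategy back to an inline column load for one query
    session.query(Book).options(undefer('summary')).all()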
- """ + """Deferred column loader, a per-column or per-column-group lazy loader.""" def create_row_processor(self, selectcontext, mapper, row): - if self.group is not None and selectcontext.attributes.get(('undefer', self.group), False): + if self.columns[0] in row: return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row) - elif not self.is_default or len(selectcontext.options): - def execute(instance, row, isnew, **flags): - if isnew: - if self._should_log_debug: - self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) - sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self.setup_loader(instance)) - return (execute, None) + elif not self.is_class_level or len(selectcontext.options): + def new_execute(instance, row, **flags): + if self._should_log_debug: + self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) + instance._state.set_callable(self.key, self.setup_loader(instance)) + return (new_execute, None, None) else: - def execute(instance, row, isnew, **flags): - if isnew: - if self._should_log_debug: - self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) - sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) - return (execute, None) + def new_execute(instance, row, **flags): + if self._should_log_debug: + self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) + instance._state.reset(self.key) + return (new_execute, None, None) def init(self): super(DeferredColumnLoader, self).init() @@ -151,63 +118,90 @@ class DeferredColumnLoader(LoaderStrategy): self._should_log_debug = logging.is_debug_enabled(self.logger) def init_class_attribute(self): + self.is_class_level = True self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) - sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) + sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) - def setup_query(self, context, **kwargs): - if self.group is not None and context.attributes.get(('undefer', self.group), False): + def setup_query(self, context, only_load_props=None, **kwargs): + if \ + (self.group is not None and context.attributes.get(('undefer', self.group), False)) or \ + (only_load_props and self.key in only_load_props): + self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs) - - def setup_loader(self, instance): - localparent = mapper.object_mapper(instance, raiseerror=False) - if localparent is None: + + def class_level_loader(self, instance, props=None): + if not mapper.has_mapper(instance): return None + localparent = mapper.object_mapper(instance) + + # adjust for the ColumnProperty associated with the instance + # not being our own ColumnProperty. 
This can occur when entity_name + # mappers are used to map different versions of the same ColumnProperty + # to the class. prop = localparent.get_property(self.key) if prop is not self.parent_property: return prop._get_strategy(DeferredColumnLoader).setup_loader(instance) - def lazyload(): - if not mapper.has_identity(instance): - return None - - if self.group is not None: - group = [p for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group] - else: - group = None + return LoadDeferredColumns(instance, self.key, props) + + def setup_loader(self, instance, props=None, create_statement=None): + return LoadDeferredColumns(instance, self.key, props, optimizing_statement=create_statement) - if self._should_log_debug: - self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None')) +DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader) - session = sessionlib.object_session(instance) - if session is None: - raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) - - clause = localparent._get_clause - ident = instance._instance_key[1] - params = {} - for i, primary_key in enumerate(localparent.primary_key): - params[primary_key._label] = ident[i] - if group is not None: - statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True) - else: - statement = sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True) - - if group is not None: - result = session.execute(statement, params, mapper=localparent) - try: - row = result.fetchone() - for prop in group: - sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) - return attributes.ATTR_WAS_SET - finally: - result.close() - else: - return session.scalar(sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True),params, mapper=localparent) +class LoadDeferredColumns(object): + """callable, serializable loader object used by DeferredColumnLoader""" + + def __init__(self, instance, key, keys, optimizing_statement=None): + self.instance = instance + self.key = key + self.keys = keys + self.optimizing_statement = optimizing_statement + + def __getstate__(self): + return {'instance':self.instance, 'key':self.key, 'keys':self.keys} + + def __setstate__(self, state): + self.instance = state['instance'] + self.key = state['key'] + self.keys = state['keys'] + self.optimizing_statement = None + + def __call__(self): + if not mapper.has_identity(self.instance): + return None + + localparent = mapper.object_mapper(self.instance, raiseerror=False) + + prop = localparent.get_property(self.key) + strategy = prop._get_strategy(DeferredColumnLoader) - return lazyload - -DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader) + if self.keys: + toload = self.keys + elif strategy.group: + toload = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==strategy.group] + else: + toload = [self.key] + + # narrow the keys down to just those which have no history + group = [k for k in toload if k in self.instance._state.unmodified] + + if strategy._should_log_debug: + strategy.logger.debug("deferred load %s group %s" % 
(mapperutil.attribute_str(self.instance, self.key), group and ','.join(group) or 'None')) + + session = sessionlib.object_session(self.instance) + if session is None: + raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key)) + + query = session.query(localparent) + if not self.optimizing_statement: + ident = self.instance._instance_key[1] + query._get(None, ident=ident, only_load_props=group, refresh_instance=self.instance._state) + else: + statement, params = self.optimizing_statement(self.instance) + query.from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=self.instance._state) + return attributes.ATTR_WAS_SET class DeferredOption(StrategizedOption): def __init__(self, key, defer=False): @@ -223,298 +217,358 @@ class DeferredOption(StrategizedOption): class UndeferGroupOption(MapperOption): def __init__(self, group): self.group = group - def process_query_context(self, context): - context.attributes[('undefer', self.group)] = True - - def process_selection_context(self, context): - context.attributes[('undefer', self.group)] = True + def process_query(self, query): + query._attributes[('undefer', self.group)] = True class AbstractRelationLoader(LoaderStrategy): def init(self): super(AbstractRelationLoader, self).init() - for attr in ['primaryjoin', 'secondaryjoin', 'secondary', 'foreign_keys', 'mapper', 'select_mapper', 'target', 'select_table', 'loads_polymorphic', 'uselist', 'cascade', 'attributeext', 'order_by', 'remote_side', 'polymorphic_primaryjoin', 'polymorphic_secondaryjoin', 'direction']: + for attr in ['primaryjoin', 'secondaryjoin', 'secondary', 'foreign_keys', 'mapper', 'target', 'table', 'uselist', 'cascade', 'attributeext', 'order_by', 'remote_side', 'direction']: setattr(self, attr, getattr(self.parent_property, attr)) self._should_log_debug = logging.is_debug_enabled(self.logger) def _init_instance_attribute(self, instance, callable_=None): - return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_) + if callable_: + instance._state.set_callable(self.key, callable_) + else: + instance._state.initialize(self.key) - def _register_attribute(self, class_, callable_=None): + def _register_attribute(self, class_, callable_=None, **kwargs): self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__)) - sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator) + sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, **kwargs) class NoLoader(AbstractRelationLoader): def init_class_attribute(self): + self.is_class_level = True self._register_attribute(self.parent.class_) def create_row_processor(self, selectcontext, mapper, row): - if not self.is_default or len(selectcontext.options): - def execute(instance, row, isnew, **flags): - if isnew: - if self._should_log_debug: - self.logger.debug("set instance-level no loader on %s" % 
mapperutil.attribute_str(instance, self.key)) - self._init_instance_attribute(instance) - return (execute, None) - else: - return (None, None) + def new_execute(instance, row, ispostselect, **flags): + if not ispostselect: + if self._should_log_debug: + self.logger.debug("initializing blank scalar/collection on %s" % mapperutil.attribute_str(instance, self.key)) + self._init_instance_attribute(instance) + return (new_execute, None, None) NoLoader.logger = logging.class_logger(NoLoader) class LazyLoader(AbstractRelationLoader): def init(self): super(LazyLoader, self).init() - (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self) + (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property) - self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere)) + self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.__lazywhere)) # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() #from sqlalchemy.orm import query - self.use_get = not self.uselist and self.mapper._get_clause.compare(self.lazywhere) + self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere) if self.use_get: self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads") def init_class_attribute(self): - self._register_attribute(self.parent.class_, callable_=lambda i: self.setup_loader(i)) + self.is_class_level = True + self._register_attribute(self.parent.class_, callable_=self.class_level_loader) + + def lazy_clause(self, instance, reverse_direction=False): + if instance is None: + return self._lazy_none_clause(reverse_direction) + + if not reverse_direction: + (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns) + else: + (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) + + def visit_bindparam(bindparam): + mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent + if bindparam.key in bind_to_col: + # use the "committed" (database) version to get query column values + # also its a deferred value; so that when used by Query, the committed value is used + # after an autoflush occurs + bindparam.value = lambda: mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key]) + return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam) + + def _lazy_none_clause(self, reverse_direction=False): + if not reverse_direction: + (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns) + else: + (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) - def setup_loader(self, instance, options=None): + def visit_binary(binary): + mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent + if isinstance(binary.left, expression._BindParamClause) and binary.left.key in bind_to_col: + # reverse order if the NULL is on the left side + binary.left = binary.right + binary.right = expression.null() + binary.operator = operators.is_ + elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in bind_to_col: + binary.right = expression.null() + binary.operator = operators.is_ + + return visitors.traverse(criterion, clone=True, visit_binary=visit_binary) + + def 
class_level_loader(self, instance, options=None, path=None): if not mapper.has_mapper(instance): return None - else: - prop = mapper.object_mapper(instance).get_property(self.key) - if prop is not self.parent_property: - return prop._get_strategy(LazyLoader).setup_loader(instance) - def lazyload(): - self.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance))) - params = {} - allparams = True - # if the instance wasnt loaded from the database, then it cannot lazy load - # child items. one reason for this is that a bi-directional relationship - # will not update properly, since bi-directional uses lazy loading functions - # in both directions, and this instance will not be present in the lazily-loaded - # results of the other objects since its not in the database - if not mapper.has_identity(instance): - return None - #print "setting up loader, lazywhere", str(self.lazywhere), "binds", self.lazybinds - for col, bind in self.lazybinds.iteritems(): - params[bind.key] = self.parent.get_attr_by_column(instance, col) - if params[bind.key] is None: - allparams = False - break - if not allparams: - return None + localparent = mapper.object_mapper(instance) - session = sessionlib.object_session(instance) - if session is None: - try: - session = mapper.object_mapper(instance).get_session() - except exceptions.InvalidRequestError: - raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) - - # if we have a simple straight-primary key load, use mapper.get() - # to possibly save a DB round trip - q = session.query(self.mapper) - if self.use_get: - ident = [] - # TODO: when options are added to allow switching between union-based and non-union - # based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper", - # probably via the query's own "mapper" property, and also use one of two "lazy" clauses, - # one against the "union" the other not - for primary_key in self.select_mapper.primary_key: - bind = self.lazyreverse[primary_key] - ident.append(params[bind.key]) - return q.get(ident) - elif self.order_by is not False: - q = q.order_by(self.order_by) - elif self.secondary is not None and self.secondary.default_order_by() is not None: - q = q.order_by(self.secondary.default_order_by()) - - if options: - q = q.options(*options) - q = q.filter(self.lazywhere).params(**params) - - result = q.all() - if self.uselist: - return result - else: - if len(result): - return result[0] - else: - return None - - if self.uselist: - return q.all() - else: - return q.first() + # adjust for the PropertyLoader associated with the instance + # not being our own PropertyLoader. This can occur when entity_name + # mappers are used to map different versions of the same PropertyLoader + # to the class. 
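The lazy_clause() traversal above swaps join-condition columns for bind parameters whose values are supplied by callables, so the committed parent value is read only when the lazy load actually executes (i.e. after any autoflush). A minimal standalone sketch of that idea; DeferredParam and resolve_params are hypothetical names, not the SQLAlchemy internals:

    class DeferredParam(object):
        """A bind value computed at execution time, not at clause-construction time."""
        def __init__(self, getter):
            self.getter = getter
        def __call__(self):
            return self.getter()

    def resolve_params(params):
        # evaluate any callable bind values just before the statement runs
        resolved = {}
        for key, value in params.items():
            if callable(value):
                value = value()
            resolved[key] = value
        return resolved

    class Parent(object):
        def __init__(self, id_):
            self.id = id_

    parent = Parent(1)
    criterion_params = {'parent_id': DeferredParam(lambda: parent.id)}
    parent.id = 2                       # value changes before execution
    assert resolve_params(criterion_params) == {'parent_id': 2}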
+ prop = localparent.get_property(self.key) + if prop is not self.parent_property: + return prop._get_strategy(LazyLoader).setup_loader(instance) + + return LoadLazyAttribute(instance, self.key, options, path) - return lazyload + def setup_loader(self, instance, options=None, path=None): + return LoadLazyAttribute(instance, self.key, options, path) def create_row_processor(self, selectcontext, mapper, row): - if not self.is_default or len(selectcontext.options): - def execute(instance, row, isnew, **flags): - if isnew: + if not self.is_class_level or len(selectcontext.options): + def new_execute(instance, row, ispostselect, **flags): + if not ispostselect: if self._should_log_debug: self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, - # which will override the clareset_instance_attributess-level behavior - self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options)) - return (execute, None) + # which will override the class-level behavior + + self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options, selectcontext.query._current_path + selectcontext.path)) + return (new_execute, None, None) else: - def execute(instance, row, isnew, **flags): - if isnew: + def new_execute(instance, row, ispostselect, **flags): + if not ispostselect: if self._should_log_debug: self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) # we are the primary manager for this attribute on this class - reset its per-instance attribute state, # so that the class-level lazy loader is executed when next referenced on this instance. # this usually is not needed unless the constructor of the object referenced the attribute before we got # to load data into it. 
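In user terms, the class-level lazy loader is what a plain relation() installs; attribute access then invokes the (now picklable) loader callable. Roughly, assuming User/Address mappings are already configured with relation(Address, lazy=True):

    u = session.query(User).first()
    # 'addresses' was not loaded with the row; touching it fires the
    # class-level lazy loader, which issues a second SELECT using the
    # lazy clause bound to u's committed primary key value
    u.addresses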
- sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) - return (execute, None) + instance._state.reset(self.key) + return (new_execute, None, None) - def _create_lazy_clause(cls, prop, reverse_direction=False, op='=='): - (primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side) - + def __create_lazy_clause(cls, prop, reverse_direction=False): binds = {} - reverse = {} - - def should_bind(targetcol, othercol): - if reverse_direction and not secondaryjoin: - return targetcol in remote_side - else: - return othercol in remote_side - - def find_column_in_expr(expr): - if not isinstance(expr, sql.ColumnElement): - return None - columns = [] - class FindColumnInColumnClause(sql.ClauseVisitor): - def visit_column(self, c): - columns.append(c) - FindColumnInColumnClause().traverse(expr) - return len(columns) and columns[0] or None + lookup = {} + equated_columns = {} + + if reverse_direction and not prop.secondaryjoin: + for l, r in prop.local_remote_pairs: + _list = lookup.setdefault(r, []) + _list.append((r, l)) + equated_columns[l] = r + else: + for l, r in prop.local_remote_pairs: + _list = lookup.setdefault(l, []) + _list.append((l, r)) + equated_columns[r] = l + + def col_to_bind(col): + if col in lookup: + for tobind, equated in lookup[col]: + if equated in binds: + return None + if col not in binds: + binds[col] = sql.bindparam(None, None, type_=col.type) + return binds[col] + return None + + lazywhere = prop.primaryjoin - def visit_binary(binary): - leftcol = find_column_in_expr(binary.left) - rightcol = find_column_in_expr(binary.right) - if leftcol is None or rightcol is None: - return - - if should_bind(leftcol, rightcol): - col = leftcol - binary.left = binds.setdefault(leftcol, - sql.bindparam(None, None, shortname=leftcol.name, type_=binary.right.type, unique=True)) - reverse[rightcol] = binds[col] - - # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1", - # which can happen in rare cases (test/orm/relationships.py RelationTest2) - if leftcol is not rightcol and should_bind(rightcol, leftcol): - col = rightcol - binary.right = binds.setdefault(rightcol, - sql.bindparam(None, None, shortname=rightcol.name, type_=binary.left.type, unique=True)) - reverse[leftcol] = binds[col] - - lazywhere = primaryjoin - li = mapperutil.BinaryVisitor(visit_binary) - - if not secondaryjoin or not reverse_direction: - lazywhere = li.traverse(lazywhere, clone=True) - - if secondaryjoin is not None: + if not prop.secondaryjoin or not reverse_direction: + lazywhere = visitors.traverse(lazywhere, before_clone=col_to_bind, clone=True) + + if prop.secondaryjoin is not None: + secondaryjoin = prop.secondaryjoin if reverse_direction: - secondaryjoin = li.traverse(secondaryjoin, clone=True) + secondaryjoin = visitors.traverse(secondaryjoin, before_clone=col_to_bind, clone=True) lazywhere = sql.and_(lazywhere, secondaryjoin) - return (lazywhere, binds, reverse) - _create_lazy_clause = classmethod(_create_lazy_clause) + + bind_to_col = dict([(binds[col].key, col) for col in binds]) + + return (lazywhere, bind_to_col, equated_columns) + __create_lazy_clause = classmethod(__create_lazy_clause) LazyLoader.logger = logging.class_logger(LazyLoader) +class LoadLazyAttribute(object): + """callable, serializable loader object used by LazyLoader""" + + def __init__(self, instance, key, options, path): + self.instance = instance + self.key = key + self.options = options + self.path 
= path + + def __getstate__(self): + return {'instance':self.instance, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)} + + def __setstate__(self, state): + self.instance = state['instance'] + self.key = state['key'] + self.options= state['options'] + self.path = deserialize_path(state['path']) + + def __call__(self): + instance = self.instance + + if not mapper.has_identity(instance): + return None + + instance_mapper = mapper.object_mapper(instance) + prop = instance_mapper.get_property(self.key) + strategy = prop._get_strategy(LazyLoader) + + if strategy._should_log_debug: + strategy.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance))) + + session = sessionlib.object_session(instance) + if session is None: + try: + session = instance_mapper.get_session() + except exceptions.InvalidRequestError: + raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) + + q = session.query(prop.mapper).autoflush(False) + if self.path: + q = q._with_current_path(self.path) + + # if we have a simple primary key load, use mapper.get() + # to possibly save a DB round trip + if strategy.use_get: + ident = [] + allnulls = True + for primary_key in prop.mapper.primary_key: + val = instance_mapper._get_committed_attr_by_column(instance, strategy._equated_columns[primary_key]) + allnulls = allnulls and val is None + ident.append(val) + if allnulls: + return None + if self.options: + q = q._conditional_options(*self.options) + return q.get(ident) + + if strategy.order_by is not False: + q = q.order_by(strategy.order_by) + elif strategy.secondary is not None and strategy.secondary.default_order_by() is not None: + q = q.order_by(strategy.secondary.default_order_by()) + + if self.options: + q = q._conditional_options(*self.options) + q = q.filter(strategy.lazy_clause(instance)) + + result = q.all() + if strategy.uselist: + return result + else: + if result: + return result[0] + else: + return None + class EagerLoader(AbstractRelationLoader): """Loads related objects inline with a parent query.""" def init(self): super(EagerLoader, self).init() - if self.is_default: - self.parent._eager_loaders.add(self.parent_property) - self.clauses = {} self.join_depth = self.parent_property.join_depth def init_class_attribute(self): + # class-level eager strategy; add the PropertyLoader + # to the parent's list of "eager loaders"; this tells the Query + # that eager loaders will be used in a normal query + self.parent._eager_loaders.add(self.parent_property) + + # initialize a lazy loader on the class level attribute self.parent_property._get_strategy(LazyLoader).init_class_attribute() def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs): """Add a left outer join to the statement thats being constructed.""" - # build a path as we setup the query. the format of this path - # matches that of interfaces.LoaderStack, and will be used in the - # row-loading phase to match up AliasedClause objects with the current - # LoaderStack position. 
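The "path" referred to here is a flat tuple alternating mapper and relation key, one pair per join hop; both the join_depth limit and the basic recursion check work off its length. A toy illustration, with plain strings standing in for mapper objects:

    path = ('User', 'addresses')                   # one hop: User.addresses
    nested = path + ('Address', 'email_info')      # one hop deeper
    join_depth = 1
    stop = (len(nested) / 2) > join_depth          # True: don't eager-join further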
- if parentclauses: - path = parentclauses.path + (self.parent.base_mapper(), self.key) - else: - path = (self.parent.base_mapper(), self.key) - + path = context.path - if self.join_depth: - if len(path) / 2 > self.join_depth: - return + # check for join_depth or basic recursion, + # if the current path was not explicitly stated as + # a desired "loaderstrategy" (i.e. via query.options()) + if ("loaderstrategy", path) not in context.attributes: + if self.join_depth: + if len(path) / 2 > self.join_depth: + return + else: + if self.mapper.base_mapper in path: + return + + if ("eager_row_processor", path) in context.attributes: + # if user defined eager_row_processor, that's contains_eager(). + # don't render LEFT OUTER JOIN, generate an AliasedClauses from + # the decorator (this is a hack here, cleaned up in 0.5) + cl = context.attributes[("eager_row_processor", path)] + if cl: + row = cl(None) + class ActsLikeAliasedClauses(object): + def aliased_column(self, col): + return row.map[col] + clauses = ActsLikeAliasedClauses() + else: + clauses = None else: - if self.mapper in path: + clauses = self.__create_eager_join(context, path, parentclauses, parentmapper, **kwargs) + if not clauses: return - - #print "CREATING EAGER PATH FOR", "->".join([str(s) for s in path]) - + + for value in self.mapper._iterate_polymorphic_properties(): + context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.mapper) + + def __create_eager_join(self, context, path, parentclauses, parentmapper, **kwargs): if parentmapper is None: localparent = context.mapper else: localparent = parentmapper - statement = context.statement - - if hasattr(statement, '_outerjoin'): - towrap = statement._outerjoin - elif isinstance(localparent.mapped_table, sql.Join): - towrap = localparent.mapped_table + if context.eager_joins: + towrap = context.eager_joins else: - # look for the mapper's selectable expressed within the current "from" criterion. - # this will locate the selectable inside of any containers it may be a part of (such - # as a join). if its inside of a join, we want to outer join on that join, not the - # selectable. - for fromclause in statement.froms: - if fromclause is localparent.mapped_table: - towrap = fromclause - break - elif isinstance(fromclause, sql.Join): - if localparent.mapped_table in sql_util.TableFinder(fromclause, include_aliases=True): - towrap = fromclause - break - else: - raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table)) - + towrap = context.from_clause + + # create AliasedClauses object to build up the eager query. this is cached after 1st creation. 
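At the query level this corresponds to the eagerload() option (or a lazy=False relation), which renders the aliased LEFT OUTER JOIN built here, and to contains_eager(), which suppresses the join generation and decorates rows from a user-supplied statement. A rough sketch only; users_table and addresses_table are assumed Table objects and the mappings are assumed to define an 'addresses' relation:

    from sqlalchemy.orm import eagerload, contains_eager

    # eager load: the aliased LEFT OUTER JOIN is generated by the strategy
    users = session.query(User).options(eagerload('addresses')).all()

    # contains_eager: the join is already in the statement; rows are simply decorated
    stmt = users_table.outerjoin(addresses_table).select(use_labels=True)
    users = session.query(User).options(
        contains_eager('addresses')).from_statement(stmt).all()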
try: clauses = self.clauses[path] except KeyError: - clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.polymorphic_primaryjoin, self.parent_property.polymorphic_secondaryjoin, parentclauses) + clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.primaryjoin, self.parent_property.secondaryjoin, parentclauses) self.clauses[path] = clauses + + # place the "row_decorator" from the AliasedClauses into the QueryContext, where it will + # be picked up in create_row_processor() when results are fetched + context.attributes[("eager_row_processor", path)] = clauses.row_decorator if self.secondaryjoin is not None: - statement._outerjoin = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin) + context.eager_joins = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin) + + # TODO: check for "deferred" cols on parent/child tables here ? this would only be + # useful if the primary/secondaryjoin are against non-PK columns on the tables (and therefore might be deferred) + if self.order_by is False and self.secondary.default_order_by() is not None: - statement.append_order_by(*clauses.secondary.default_order_by()) + context.eager_order_by += clauses.secondary.default_order_by() else: - statement._outerjoin = towrap.outerjoin(clauses.alias, clauses.primaryjoin) + context.eager_joins = towrap.outerjoin(clauses.alias, clauses.primaryjoin) + # ensure all the cols on the parent side are actually in the + # columns clause (i.e. are not deferred), so that aliasing applied by the Query propagates + # those columns outward. This has the effect of "undefering" those columns. + for col in sql_util.find_columns(clauses.primaryjoin): + if localparent.mapped_table.c.contains_column(col): + context.primary_columns.append(col) + if self.order_by is False and clauses.alias.default_order_by() is not None: - statement.append_order_by(*clauses.alias.default_order_by()) + context.eager_order_by += clauses.alias.default_order_by() if clauses.order_by: - statement.append_order_by(*util.to_list(clauses.order_by)) + context.eager_order_by += util.to_list(clauses.order_by) - statement.append_from(statement._outerjoin) - - for value in self.select_mapper.iterate_properties: - value.setup(context, parentclauses=clauses, parentmapper=self.select_mapper) + return clauses def _create_row_decorator(self, selectcontext, row, path): """Create a *row decorating* function that will apply eager @@ -526,24 +580,14 @@ class EagerLoader(AbstractRelationLoader): #print "creating row decorator for path ", "->".join([str(s) for s in path]) - # check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option) - if selectcontext.attributes.has_key(("eager_row_processor", self.parent_property)): - # custom row decoration function, placed in the selectcontext by the - # contains_eager() mapper option - decorator = selectcontext.attributes[("eager_row_processor", self.parent_property)] + if ("eager_row_processor", path) in selectcontext.attributes: + decorator = selectcontext.attributes[("eager_row_processor", path)] if decorator is None: decorator = lambda row: row else: - try: - # decorate the row according to the stored AliasedClauses for this eager load - clauses = self.clauses[path] - decorator = clauses.row_decorator - except KeyError, k: - # no stored AliasedClauses: eager loading was not set up in the query and - # 
AliasedClauses never got initialized - if self._should_log_debug: - self.logger.debug("Could not locate aliased clauses for key: " + str(path)) - return None + if self._should_log_debug: + self.logger.debug("Could not locate aliased clauses for key: " + str(path)) + return None try: decorated_row = decorator(row) @@ -558,16 +602,13 @@ class EagerLoader(AbstractRelationLoader): return None def create_row_processor(self, selectcontext, mapper, row): - selectcontext.stack.push_property(self.key) - path = selectcontext.stack.snapshot() - row_decorator = self._create_row_decorator(selectcontext, row, path) + row_decorator = self._create_row_decorator(selectcontext, row, selectcontext.path) + pathstr = ','.join([str(x) for x in selectcontext.path]) if row_decorator is not None: def execute(instance, row, isnew, **flags): decorated_row = row_decorator(row) - selectcontext.stack.push_property(self.key) - if not self.uselist: if self._should_log_debug: self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key)) @@ -576,61 +617,67 @@ class EagerLoader(AbstractRelationLoader): # parent object, bypassing InstrumentedAttribute # event handlers. # - # FIXME: instead of... - sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None)) - # bypass and set directly: - #instance.__dict__[self.key] = ... + instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None) else: # call _instance on the row, even though the object has been created, # so that we further descend into properties self.mapper._instance(selectcontext, decorated_row, None) else: - if isnew: + if isnew or self.key not in instance._state.appenders: + # appender_key can be absent from selectcontext.attributes with isnew=False + # when self-referential eager loading is used; the same instance may be present + # in two distinct sets of result columns + if self._should_log_debug: self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key)) - collection = sessionlib.attribute_manager.init_collection(instance, self.key) + collection = attributes.init_collection(instance, self.key) appender = util.UniqueAppender(collection, 'append_without_event') - # store it in the "scratch" area, which is local to this load operation. - selectcontext.attributes[(instance, self.key)] = appender - result_list = selectcontext.attributes[(instance, self.key)] + instance._state.appenders[self.key] = appender + + result_list = instance._state.appenders[self.key] if self._should_log_debug: self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) - - self.select_mapper._instance(selectcontext, decorated_row, result_list) - selectcontext.stack.pop() + + self.mapper._instance(selectcontext, decorated_row, result_list) - selectcontext.stack.pop() - return (execute, None) + if self._should_log_debug: + self.logger.debug("Returning eager instance loader for %s" % str(self)) + + return (execute, execute, None) else: - self.logger.debug("eager loader %s degrading to lazy loader" % str(self)) - selectcontext.stack.pop() + if self._should_log_debug: + self.logger.debug("eager loader %s degrading to lazy loader" % str(self)) return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row) - - + def __str__(self): return str(self.parent) + "." 
+ self.key EagerLoader.logger = logging.class_logger(EagerLoader) class EagerLazyOption(StrategizedOption): - def __init__(self, key, lazy=True, chained=False): - super(EagerLazyOption, self).__init__(key) + def __init__(self, key, lazy=True, chained=False, mapper=None): + super(EagerLazyOption, self).__init__(key, mapper) self.lazy = lazy self.chained = chained def is_chained(self): return not self.lazy and self.chained - def process_query_property(self, context, properties): + def process_query_property(self, query, paths): if self.lazy: - if properties[-1] in context.eager_loaders: - context.eager_loaders.remove(properties[-1]) + if paths[-1] in query._eager_loaders: + query._eager_loaders = query._eager_loaders.difference(util.Set([paths[-1]])) else: - for prop in properties: - context.eager_loaders.add(prop) - super(EagerLazyOption, self).process_query_property(context, properties) + if not self.chained: + paths = [paths[-1]] + res = util.Set() + for path in paths: + if len(path) - len(query._current_path) == 2: + res.add(path) + query._eager_loaders = query._eager_loaders.union(res) + super(EagerLazyOption, self).process_query_property(query, paths) def get_strategy_class(self): if self.lazy: @@ -642,39 +689,22 @@ class EagerLazyOption(StrategizedOption): EagerLazyOption.logger = logging.class_logger(EagerLazyOption) -# TODO: enable FetchMode option. currently -# this class does nothing. will require Query -# to swich between using its "polymorphic" selectable -# and its regular selectable in order to make decisions -# (therefore might require that FetchModeOperation is performed -# only as the first operation on a Query.) -class FetchModeOption(PropertyOption): - def __init__(self, key, type): - super(FetchModeOption, self).__init__(key) - if type not in ('join', 'select'): - raise exceptions.ArgumentError("Fetchmode must be one of 'join' or 'select'") - self.type = type - - def process_selection_property(self, context, properties): - context.attributes[('fetchmode', properties[-1])] = self.type - class RowDecorateOption(PropertyOption): def __init__(self, key, decorator=None, alias=None): super(RowDecorateOption, self).__init__(key) self.decorator = decorator self.alias = alias - def process_selection_property(self, context, properties): + def process_query_property(self, query, paths): if self.alias is not None and self.decorator is None: + (mapper, propname) = paths[-1][-2:] + + prop = mapper.get_property(propname, resolve_synonyms=True) if isinstance(self.alias, basestring): - self.alias = properties[-1].target.alias(self.alias) - def decorate(row): - d = {} - for c in properties[-1].target.columns: - d[c] = row[self.alias.corresponding_column(c)] - return d - self.decorator = decorate - context.attributes[("eager_row_processor", properties[-1])] = self.decorator + self.alias = prop.target.alias(self.alias) + + self.decorator = mapperutil.create_row_adapter(self.alias) + query._attributes[("eager_row_processor", paths[-1])] = self.decorator RowDecorateOption.logger = logging.class_logger(RowDecorateOption) diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index cf48202b0f..39a7b5044c 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -1,152 +1,86 @@ # mapper/sync.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: 
http://www.opensource.org/licenses/mit-license.php -"""Contains the ClauseSynchronizer class, which is used to map -attributes between two objects in a manner corresponding to a SQL -clause that compares column values. +"""private module containing functions used for copying data between instances +based on join conditions. """ -from sqlalchemy import sql, schema, exceptions +from sqlalchemy import schema, exceptions, util +from sqlalchemy.sql import visitors, operators, util as sqlutil from sqlalchemy import logging from sqlalchemy.orm import util as mapperutil -import operator +from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY # legacy -ONETOMANY = 0 -MANYTOONE = 1 -MANYTOMANY = 2 - -class ClauseSynchronizer(object): - """Given a SQL clause, usually a series of one or more binary - expressions between columns, and a set of 'source' and - 'destination' mappers, compiles a set of SyncRules corresponding - to that information. - - The ClauseSynchronizer can then be executed given a set of - parent/child objects or destination dictionary, which will iterate - through each of its SyncRules and execute them. Each SyncRule - will copy the value of a single attribute from the parent to the - child, corresponding to the pair of columns in a particular binary - expression, using the source and destination mappers to map those - two columns to object attributes within parent and child. - """ - - def __init__(self, parent_mapper, child_mapper, direction): - self.parent_mapper = parent_mapper - self.child_mapper = child_mapper - self.direction = direction - self.syncrules = [] - - def compile(self, sqlclause, foreign_keys=None, issecondary=None): - def compile_binary(binary): - """Assemble a SyncRule given a single binary condition.""" - - if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): - return - - source_column = None - dest_column = None - - if foreign_keys is None: - if binary.left.table == binary.right.table: - raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync") - - if binary.left in [f.column for f in binary.right.foreign_keys]: - dest_column = binary.right - source_column = binary.left - elif binary.right in [f.column for f in binary.left.foreign_keys]: - dest_column = binary.left - source_column = binary.right - else: - if binary.left in foreign_keys: - source_column=binary.right - dest_column = binary.left - elif binary.right in foreign_keys: - source_column = binary.left - dest_column = binary.right - - if source_column and dest_column: - if self.direction == ONETOMANY: - self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper)) - elif self.direction == MANYTOONE: - self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper)) - else: - if not issecondary: - self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper, issecondary=issecondary)) - else: - self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary)) - - rules_added = len(self.syncrules) - BinaryVisitor(compile_binary).traverse(sqlclause) - if len(self.syncrules) == rules_added: - raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause)) - - def dest_columns(self): - return [r.dest_column for r in self.syncrules if r.dest_column is not 
None] - - def execute(self, source, dest, obj=None, child=None, clearkeys=None): - for rule in self.syncrules: - rule.execute(source, dest, obj, child, clearkeys) - -class SyncRule(object): - """An instruction indicating how to populate the objects on each - side of a relationship. - - In other words, if table1 column A is joined against table2 column - B, and we are a one-to-many from table1 to table2, a syncrule - would say *take the A attribute from object1 and assign it to the - B attribute on object2*. - - A rule contains the source mapper, the source column, destination - column, destination mapper in the case of a one/many relationship, - and the integer direction of this mapper relative to the - association in the case of a many to many relationship. - """ - - def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None): - self.source_mapper = source_mapper - self.source_column = source_column - self.issecondary = issecondary - self.dest_mapper = dest_mapper - self.dest_column = dest_column - - #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper - - def dest_primary_key(self): +def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: try: - return self._dest_primary_key - except AttributeError: - self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper.pks_by_table[self.dest_column.table] - return self._dest_primary_key - - def execute(self, source, dest, obj, child, clearkeys): - if source is None: - if self.issecondary is False: - source = obj - elif self.issecondary is True: - source = child - if clearkeys or source is None: - value = None - clearkeys = True - else: - value = self.source_mapper.get_attr_by_column(source, self.source_column) - if isinstance(dest, dict): - dest[self.dest_column.key] = value - else: - if clearkeys and self.dest_primary_key(): - raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.instance_str(dest))) - - if logging.is_debug_enabled(self.logger): - self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value)) - self.dest_mapper.set_attr_by_column(dest, self.dest_column, value) - -SyncRule.logger = logging.class_logger(SyncRule) + value = source_mapper._get_state_attr_by_column(source, l) + except exceptions.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) -class BinaryVisitor(sql.ClauseVisitor): - def __init__(self, func): - self.func = func + try: + dest_mapper._set_state_attr_by_column(dest, r, value) + except exceptions.UnmappedColumnError: + self._raise_col_to_prop(True, source_mapper, l, dest_mapper, r) + +def clear(dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + if r.primary_key: + raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest))) + try: + dest_mapper._set_state_attr_by_column(dest, r, None) + except exceptions.UnmappedColumnError: + _raise_col_to_prop(True, None, l, dest_mapper, r) - def visit_binary(self, binary): - self.func(binary) +def update(source, source_mapper, dest, old_prefix, synchronize_pairs): + for l, r in synchronize_pairs: + try: + oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l) + value = 
source_mapper._get_state_attr_by_column(source, l) + except exceptions.UnmappedColumnError: + self._raise_col_to_prop(False, source_mapper, l, None, r) + dest[r.key] = value + dest[old_prefix + r.key] = oldvalue + +def populate_dict(source, source_mapper, dict_, synchronize_pairs): + for l, r in synchronize_pairs: + try: + value = source_mapper._get_state_attr_by_column(source, l) + except exceptions.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) + + dict_[r.key] = value + +def source_changes(uowcommit, source, source_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + try: + prop = source_mapper._get_col_to_prop(l) + except exceptions.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) + (added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True) + if added and deleted: + return True + else: + return False + +def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + try: + prop = dest_mapper._get_col_to_prop(r) + except exceptions.UnmappedColumnError: + _raise_col_to_prop(True, None, l, dest_mapper, r) + (added, unchanged, deleted) = uowcommit.get_attribute_history(dest, prop.key, passive=True) + if added and deleted: + return True + else: + return False + +def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column): + if isdest: + raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper)) + else: + raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column)) + diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index f59042810a..66b68770d6 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -1,5 +1,5 @@ # orm/unitofwork.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -19,65 +19,71 @@ new, dirty, or deleted and provides the capability to flush all those changes at once. """ +import StringIO, weakref from sqlalchemy import util, logging, topological, exceptions from sqlalchemy.orm import attributes, interfaces from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.mapper import object_mapper -import StringIO -import weakref +from sqlalchemy.orm.mapper import object_mapper, _state_mapper, has_identity # Load lazily object_session = None class UOWEventHandler(interfaces.AttributeExtension): - """An event handler added to all class attributes which handles - session operations. + """An event handler added to all relation attributes which handles + session cascade operations. 
""" - def __init__(self, key, class_, cascade=None): + def __init__(self, key, class_, cascade): self.key = key self.class_ = class_ self.cascade = cascade + + def _target_mapper(self, obj): + prop = object_mapper(obj).get_property(self.key) + return prop.mapper def append(self, obj, item, initiator): # process "save_update" cascade rules for when an instance is appended to the list of another instance sess = object_session(obj) - if sess is not None: - if self.cascade is not None and self.cascade.save_update and item not in sess: - mapper = object_mapper(obj) - prop = mapper.get_property(self.key) - ename = prop.mapper.entity_name - sess.save_or_update(item, entity_name=ename) + if sess: + if self.cascade.save_update and item not in sess: + sess.save_or_update(item, entity_name=self._target_mapper(obj).entity_name) def remove(self, obj, item, initiator): - # currently no cascade rules for removing an item from a list - # (i.e. it stays in the Session) - pass + sess = object_session(obj) + if sess: + # expunge pending orphans + if self.cascade.delete_orphan and item in sess.new: + if self._target_mapper(obj)._is_orphan(item): + sess.expunge(item) def set(self, obj, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance is attached to another instance + if oldvalue is newvalue: + return sess = object_session(obj) - if sess is not None: - if newvalue is not None and self.cascade is not None and self.cascade.save_update and newvalue not in sess: - mapper = object_mapper(obj) - prop = mapper.get_property(self.key) - ename = prop.mapper.entity_name - sess.save_or_update(newvalue, entity_name=ename) + if sess: + if newvalue is not None and self.cascade.save_update and newvalue not in sess: + sess.save_or_update(newvalue, entity_name=self._target_mapper(obj).entity_name) + if self.cascade.delete_orphan and oldvalue in sess.new: + sess.expunge(oldvalue) -class UOWAttributeManager(attributes.AttributeManager): - """Override ``AttributeManager`` to provide the ``UOWProperty`` - instance for all ``InstrumentedAttributes``. +def register_attribute(class_, key, *args, **kwargs): + """overrides attributes.register_attribute() to add UOW event handlers + to new InstrumentedAttributes. """ - - def create_prop(self, class_, key, uselist, callable_, typecallable, - cascade=None, extension=None, **kwargs): - extension = util.to_list(extension or []) + + cascade = kwargs.pop('cascade', None) + useobject = kwargs.get('useobject', False) + if useobject: + # for object-holding attributes, instrument UOWEventHandler + # to process per-attribute cascades + extension = util.to_list(kwargs.pop('extension', None) or []) extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) - - return super(UOWAttributeManager, self).create_prop( - class_, key, uselist, callable_, typecallable, - extension=extension, **kwargs) + kwargs['extension'] = extension + return attributes.register_attribute(class_, key, *args, **kwargs) + class UnitOfWork(object): @@ -88,129 +94,177 @@ class UnitOfWork(object): operation. 
""" - def __init__(self, identity_map=None, weak_identity_map=False): - if identity_map is not None: - self.identity_map = identity_map + def __init__(self, session): + if session.weak_identity_map: + self.identity_map = attributes.WeakInstanceDict() else: - if weak_identity_map: - self.identity_map = weakref.WeakValueDictionary() - else: - self.identity_map = {} + self.identity_map = attributes.StrongInstanceDict() - self.new = util.Set() #OrderedSet() - self.deleted = util.Set() - self.logger = logging.instance_logger(self) + self.new = {} # InstanceState->object, strong refs object + self.deleted = {} # same + self.logger = logging.instance_logger(self, echoflag=session.echo_uow) - echo = logging.echo_property() + def _remove_deleted(self, state): + if '_instance_key' in state.dict: + del self.identity_map[state.dict['_instance_key']] + self.deleted.pop(state, None) + self.new.pop(state, None) - def _remove_deleted(self, obj): - if hasattr(obj, "_instance_key"): - del self.identity_map[obj._instance_key] - try: - self.deleted.remove(obj) - except KeyError: - pass - try: - self.new.remove(obj) - except KeyError: - pass - - def _validate_obj(self, obj): - if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \ - (not hasattr(obj, '_instance_key') and obj not in self.new): - raise exceptions.InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj)) - - def _is_valid(self, obj): - if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \ - (not hasattr(obj, '_instance_key') and obj not in self.new): - return False + def _is_valid(self, state): + if '_instance_key' in state.dict: + return state.dict['_instance_key'] in self.identity_map else: - return True + return state in self.new + + def _register_clean(self, state): + """register the given object as 'clean' (i.e. persistent) within this unit of work, after + a save operation has taken place.""" - def register_clean(self, obj): - """register the given object as 'clean' (i.e. persistent) within this unit of work.""" + mapper = _state_mapper(state) + instance_key = mapper._identity_key_from_state(state) - if obj in self.new: - self.new.remove(obj) - if not hasattr(obj, '_instance_key'): - mapper = object_mapper(obj) - obj._instance_key = mapper.identity_key_from_instance(obj) - if hasattr(obj, '_sa_insert_order'): - delattr(obj, '_sa_insert_order') - self.identity_map[obj._instance_key] = obj - attribute_manager.commit(obj) + if '_instance_key' not in state.dict: + state.dict['_instance_key'] = instance_key + + elif state.dict['_instance_key'] != instance_key: + # primary key switch + del self.identity_map[state.dict['_instance_key']] + state.dict['_instance_key'] = instance_key + + if hasattr(state, 'insert_order'): + delattr(state, 'insert_order') + + o = state.obj() + # prevent against last minute dereferences of the object + # TODO: identify a code path where state.obj() is None + if o is not None: + self.identity_map[state.dict['_instance_key']] = o + state.commit_all() + + # remove from new last, might be the last strong ref + self.new.pop(state, None) def register_new(self, obj): """register the given object as 'new' (i.e. 
unsaved) within this unit of work.""" if hasattr(obj, '_instance_key'): raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj)) - if obj not in self.new: - self.new.add(obj) - obj._sa_insert_order = len(self.new) + if obj._state not in self.new: + self.new[obj._state] = obj + obj._state.insert_order = len(self.new) def register_deleted(self, obj): """register the given persistent object as 'to be deleted' within this unit of work.""" - if obj not in self.deleted: - self._validate_obj(obj) - self.deleted.add(obj) + self.deleted[obj._state] = obj def locate_dirty(self): """return a set of all persistent instances within this unit of work which either contain changes or are marked as deleted. """ - return util.Set([x for x in self.identity_map.values() if x not in self.deleted and attribute_manager.is_modified(x)]) + # a little bit of inlining for speed + return util.IdentitySet([x for x in self.identity_map.values() + if x._state not in self.deleted + and ( + x._state.modified + or (x.__class__._class_state.has_mutable_scalars and x._state.is_modified()) + ) + ]) def flush(self, session, objects=None): """create a dependency tree of all pending SQL operations within this unit of work and execute.""" - - # this context will track all the objects we want to save/update/delete, - # and organize a hierarchical dependency structure. it also handles - # communication with the mappers and relationships to fire off SQL - # and synchronize attributes between related objects. - echo = logging.is_info_enabled(self.logger) + dirty = [x for x in self.identity_map.all_states() + if x.modified + or (x.class_._class_state.has_mutable_scalars and x.is_modified()) + ] + + if not dirty and not self.deleted and not self.new: + return + + deleted = util.Set(self.deleted) + new = util.Set(self.new) + + dirty = util.Set(dirty).difference(deleted) + flush_context = UOWTransaction(self, session) + if session.extension is not None: + session.extension.before_flush(session, flush_context, objects) + # create the set of all objects we want to operate upon - if objects is not None: + if objects: # specific list passed in - objset = util.Set(objects) + objset = util.Set([o._state for o in objects]) else: # or just everything - objset = util.Set(self.identity_map.values()).union(self.new) - - # detect persistent objects that have changes - dirty = self.locate_dirty() - + objset = util.Set(self.identity_map.all_states()).union(new) + # store objects whose fate has been decided processed = util.Set() # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted. - for obj in self.new.union(dirty).intersection(objset).difference(self.deleted): - if obj in processed: + for state in new.union(dirty).intersection(objset).difference(deleted): + if state in processed: continue - flush_context.register_object(obj, isdelete=object_mapper(obj)._is_orphan(obj)) - processed.add(obj) + obj = state.obj() + is_orphan = _state_mapper(state)._is_orphan(obj) + if is_orphan and not has_identity(obj): + raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" % + ( + obj, + ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans]) + )) + flush_context.register_object(state, isdelete=is_orphan) + processed.add(state) # put all remaining deletes into the flush context. 
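For illustration, the bookkeeping that the reworked flush() performs above -- computing dirty states from the identity map, splitting the session's contents into "new", "dirty" and "deleted" sets, rejecting pending orphans and marking persistent orphans for deletion -- can be sketched in plain Python. This is a simplified stand-in, not SQLAlchemy's API; every name below is hypothetical:

    def partition_for_flush(identity_map, new, deleted, is_modified, is_orphan):
        """Split tracked objects into save and delete lists, mimicking the
        partitioning flush() does before handing states to the flush context."""
        dirty = set(o for o in identity_map.values()
                    if o not in deleted and is_modified(o))
        saves, removes = [], []
        for obj in set(new) | dirty:
            if is_orphan(obj):
                if obj in new:
                    # pending (never-flushed) orphans cannot be flushed at all
                    raise RuntimeError("pending instance %r is an orphan" % (obj,))
                removes.append(obj)
            else:
                saves.append(obj)
        # anything explicitly marked deleted and not already handled is a delete
        removes.extend(o for o in deleted if o not in removes)
        return saves, removes
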
- for obj in self.deleted.intersection(objset).difference(processed): - flush_context.register_object(obj, isdelete=True) + for state in deleted.intersection(objset).difference(processed): + flush_context.register_object(state, isdelete=True) + if len(flush_context.tasks) == 0: + return + session.create_transaction(autoflush=False) flush_context.transaction = session.transaction try: flush_context.execute() + + if session.extension is not None: + session.extension.after_flush(session, flush_context) + session.commit() except: session.rollback() raise - session.commit() flush_context.post_exec() + if session.extension is not None: + session.extension.after_flush_postexec(session, flush_context) + + def prune_identity_map(self): + """Removes unreferenced instances cached in a strong-referencing identity map. + + Note that this method is only meaningful if "weak_identity_map" + on the parent Session is set to False and therefore this UnitOfWork's + identity map is a regular dictionary + + Removes any object in the identity map that is not referenced + in user code or scheduled for a unit of work operation. Returns + the number of objects pruned. + """ + + if isinstance(self.identity_map, attributes.WeakInstanceDict): + return 0 + ref_count = len(self.identity_map) + dirty = self.locate_dirty() + keepers = weakref.WeakValueDictionary(self.identity_map) + self.identity_map.clear() + self.identity_map.update(keepers) + return ref_count - len(self.identity_map) + class UOWTransaction(object): """Handles the details of organizing and executing transaction tasks during a UnitOfWork object's flush() operation. @@ -237,55 +291,71 @@ class UOWTransaction(object): # information. self.attributes = {} - self.logger = logging.instance_logger(self) - self.echo = uow.echo - - echo = logging.echo_property() - - def register_object(self, obj, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs): - """Add an object to this ``UOWTransaction`` to be updated in the database. - - This operation has the combined effect of locating/creating an appropriate - ``UOWTask`` object, and calling its ``append()`` method which then locates/creates - an appropriate ``UOWTaskElement`` object. 
- """ + self.logger = logging.instance_logger(self, echoflag=session.echo_uow) + + def get_attribute_history(self, state, key, passive=True): + hashkey = ("history", state, key) + + # cache the objects, not the states; the strong reference here + # prevents newly loaded objects from being dereferenced during the + # flush process + if hashkey in self.attributes: + (added, unchanged, deleted, cached_passive) = self.attributes[hashkey] + # if the cached lookup was "passive" and now we want non-passive, do a non-passive + # lookup and re-cache + if cached_passive and not passive: + (added, unchanged, deleted) = attributes.get_history(state, key, passive=False) + self.attributes[hashkey] = (added, unchanged, deleted, passive) + else: + (added, unchanged, deleted) = attributes.get_history(state, key, passive=passive) + self.attributes[hashkey] = (added, unchanged, deleted, passive) - #print "REGISTER", repr(obj), repr(getattr(obj, '_instance_key', None)), str(isdelete), str(listonly) + if added is None: + return (added, unchanged, deleted) + else: + return ( + [getattr(c, '_state', c) for c in added], + [getattr(c, '_state', c) for c in unchanged], + [getattr(c, '_state', c) for c in deleted], + ) + + def register_object(self, state, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs): # if object is not in the overall session, do nothing - if not self.uow._is_valid(obj): - if logging.is_debug_enabled(self.logger): - self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.instance_str(obj))) + if not self.uow._is_valid(state): + if self._should_log_debug: + self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.state_str(state))) return - if logging.is_debug_enabled(self.logger): - self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.instance_str(obj), isdelete, listonly, postupdate)) + if self._should_log_debug: + self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.state_str(state), isdelete, listonly, postupdate)) - mapper = object_mapper(obj) + mapper = _state_mapper(state) + task = self.get_task_by_mapper(mapper) if postupdate: - task.append_postupdate(obj, post_update_cols) - return - - task.append(obj, listonly, isdelete=isdelete, **kwargs) + task.append_postupdate(state, post_update_cols) + else: + task.append(state, listonly, isdelete=isdelete, **kwargs) - def unregister_object(self, obj): - """remove an object from its parent UOWTask. + def set_row_switch(self, state): + """mark a deleted object as a 'row switch'. - called by mapper.save_obj() when an 'identity switch' is detected, so that - no further operations occur upon the instance.""" - mapper = object_mapper(obj) + this indicates that an INSERT statement elsewhere corresponds to this DELETE; + the INSERT is converted to an UPDATE and the DELETE does not occur. 
+ """ + mapper = _state_mapper(state) task = self.get_task_by_mapper(mapper) - if obj in task._objects: - task.delete(obj) - - def is_deleted(self, obj): - """return true if the given object is marked as deleted within this UOWTransaction.""" + taskelement = task._objects[state] + taskelement.isdelete = "rowswitch" - mapper = object_mapper(obj) + def is_deleted(self, state): + """return true if the given state is marked as deleted within this UOWTransaction.""" + + mapper = _state_mapper(state) task = self.get_task_by_mapper(mapper) - return task.is_deleted(obj) - + return task.is_deleted(state) + def get_task_by_mapper(self, mapper, dontcreate=False): """return UOWTask element corresponding to the given mapper. @@ -298,18 +368,16 @@ class UOWTransaction(object): if dontcreate: return None - base_mapper = mapper.base_mapper() + base_mapper = mapper.base_mapper if base_mapper in self.tasks: base_task = self.tasks[base_mapper] else: - base_task = UOWTask(self, base_mapper) - self.tasks[base_mapper] = base_task - base_mapper.register_dependencies(self) + self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper) + base_mapper._register_dependencies(self) if mapper not in self.tasks: - task = UOWTask(self, mapper, base_task=base_task) - self.tasks[mapper] = task - mapper.register_dependencies(self) + self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task) + mapper._register_dependencies(self) else: task = self.tasks[mapper] @@ -323,55 +391,39 @@ class UOWTransaction(object): by another. """ - # correct for primary mapper (the mapper offcially associated with the class) + # correct for primary mapper # also convert to the "base mapper", the parentmost task at the top of an inheritance chain # dependency sorting is done via non-inheriting mappers only, dependencies between mappers # in the same inheritance chain is done at the per-object level - mapper = mapper.primary_mapper().base_mapper() - dependency = dependency.primary_mapper().base_mapper() + mapper = mapper.primary_mapper().base_mapper + dependency = dependency.primary_mapper().base_mapper self.dependencies.add((mapper, dependency)) def register_processor(self, mapper, processor, mapperfrom): - """register a dependency processor object, corresponding to dependencies between + """register a dependency processor, corresponding to dependencies between the two given mappers. - In reality, the processor is an instance of ``dependency.DependencyProcessor`` - and is registered as a result of the ``mapper.register_dependencies()`` call in - ``get_task_by_mapper()``. - - The dependency processor supports the methods ``preprocess_dependencies()`` and - ``process_dependencies()``, which - perform operations on a list of instances that have a dependency relationship - with some other instance. The operations include adding items to the UOW - corresponding to some cascade operations, issuing inserts/deletes on - association tables, and synchronzing foreign key values between related objects - before the dependent object is operated upon at the SQL level. 
""" - # when the task from "mapper" executes, take the objects from the task corresponding - # to "mapperfrom"'s list of save/delete objects, and send them to "processor" - # for dependency processing - - #print "registerprocessor", str(mapper), repr(processor), repr(processor.key), str(mapperfrom) - - # correct for primary mapper (the mapper offcially associated with the class) + # correct for primary mapper mapper = mapper.primary_mapper() mapperfrom = mapperfrom.primary_mapper() task = self.get_task_by_mapper(mapper) targettask = self.get_task_by_mapper(mapperfrom) up = UOWDependencyProcessor(processor, targettask) - task._dependencies.add(up) + task.dependencies.add(up) def execute(self): """Execute this UOWTransaction. - This will organize all collected UOWTasks into a toplogically-sorted - dependency tree, which is then traversed using the traversal scheme + This will organize all collected UOWTasks into a dependency-sorted + list which is then traversed using the traversal scheme encoded in the UOWExecutor class. Operations to mappers and dependency processors are fired off in order to issue SQL to the database and - to maintain instance state during the execution.""" + synchronize instance attributes with database values and related + foreign key values.""" # pre-execute dependency processors. this process may # result in new tasks, objects and/or dependency processors being added, @@ -387,16 +439,19 @@ class UOWTransaction(object): if not ret: break - head = self._sort_dependencies() - if self.echo: - if head is None: - self.logger.info("Task dump: None") - else: - self.logger.info("Task dump:\n" + head.dump()) - if head is not None: - UOWExecutor().execute(self, head) - self.logger.info("Execute Complete") + tasks = self._sort_dependencies() + if self._should_log_info: + self.logger.info("Task dump:\n" + self._dump(tasks)) + UOWExecutor().execute(self, tasks) + if self._should_log_info: + self.logger.info("Execute Complete") + def _dump(self, tasks): + buf = StringIO.StringIO() + import uowdumper + uowdumper.UOWDumper(tasks, buf) + return buf.getvalue() + def post_exec(self): """mark processed objects as clean / deleted after a successful flush(). @@ -406,49 +461,31 @@ class UOWTransaction(object): for task in self.tasks.values(): for elem in task.elements: - if elem.obj is None: + if elem.state is None: continue if elem.isdelete: - self.uow._remove_deleted(elem.obj) + self.uow._remove_deleted(elem.state) else: - self.uow.register_clean(elem.obj) + self.uow._register_clean(elem.state) def _sort_dependencies(self): - """Create a hierarchical tree of dependent UOWTask instances. - - The root UOWTask is returned. - - Cyclical relationships - within the toplogical sort are further broken down into new - temporary UOWTask insances which represent smaller sub-groups of objects - that would normally belong to a single UOWTask. 
- - """ - - def sort_hier(node): - if node is None: - return None - task = self.get_task_by_mapper(node.item) - if node.cycles is not None: - tasks = [] - for n in node.cycles: - tasks.append(self.get_task_by_mapper(n.item)) - task = task._sort_circular_dependencies(self, tasks) - for child in node.children: - t = sort_hier(child) - if t is not None: - task.childtasks.append(t) - return task + nodes = topological.sort_with_cycles(self.dependencies, + [t.mapper for t in self.tasks.values() if t.base_task is t] + ) + + ret = [] + for item, cycles in nodes: + task = self.get_task_by_mapper(item) + if cycles: + for t in task._sort_circular_dependencies(self, [self.get_task_by_mapper(i) for i in cycles]): + ret.append(t) + else: + ret.append(task) - # get list of base mappers - mappers = [t.mapper for t in self.tasks.values() if t.base_task is t] - head = topological.QueueDependencySorter(self.dependencies, mappers).sort(allow_all_cycles=True) - if logging.is_debug_enabled(self.logger): + if self._should_log_debug: self.logger.debug("Dependent tuples:\n" + "\n".join(["(%s->%s)" % (d[0].class_.__name__, d[1].class_.__name__) for d in self.dependencies])) - self.logger.debug("Dependency sort:\n"+ str(head)) - task = sort_hier(head) - return task - + self.logger.debug("Dependency sort:\n"+ str(ret)) + return ret class UOWTask(object): """Represents all of the objects in the UOWTransaction which correspond to @@ -457,7 +494,6 @@ class UOWTask(object): """ def __init__(self, uowtransaction, mapper, base_task=None): - # the transaction owning this UOWTask self.uowtransaction = uowtransaction # base_task is the UOWTask which represents the "base mapper" @@ -476,31 +512,11 @@ class UOWTask(object): # the Mapper which this UOWTask corresponds to self.mapper = mapper - # a dictionary mapping object instances to a corresponding UOWTaskElement. - # Each UOWTaskElement represents one object instance which is to be saved or - # deleted by this UOWTask's Mapper. - # in the case of the row-based "cyclical sort", the UOWTaskElement may - # also reference further UOWTasks which are dependent on that UOWTaskElement. + # mapping of InstanceState -> UOWTaskElement self._objects = {} - # a set of UOWDependencyProcessor instances, which are executed after saves and - # before deletes, to synchronize data between dependent objects as well as to - # ensure that relationship cascades populate the flush() process with all - # appropriate objects. - self._dependencies = util.Set() - - # a list of UOWTasks which are sub-nodes to this UOWTask. this list - # is populated during the dependency sorting operation. - self.childtasks = [] - - # a list of UOWDependencyProcessor instances - # which derive from the UOWDependencyProcessor instances present in a - # corresponding UOWTask's "_dependencies" set. This collection is populated - # during a row-based cyclical sorting operation and only corresponds to - # new UOWTask instances created during this operation, which are also local - # to the dependency graph (i.e. they are not present in the get_task_by_mapper() - # collection). - self._cyclical_dependencies = util.Set() + self.dependencies = util.Set() + self.cyclical_dependencies = util.Set() def polymorphic_tasks(self): """return an iterator of UOWTask objects corresponding to the inheritance sequence @@ -531,65 +547,29 @@ class UOWTask(object): t = self.base_task._inheriting_tasks.get(mapper, None) if t is not None: yield t - + def is_empty(self): """return True if this UOWTask is 'empty', meaning it has no child items. 
- + used only for debugging output. """ - - return len(self._objects) == 0 and len(self._dependencies) == 0 and len(self.childtasks) == 0 - def append(self, obj, listonly = False, childtask = None, isdelete = False): - """Append an object to this task to be persisted or deleted. - - The actual record added to the ``UOWTask`` is a ``UOWTaskElement`` object - corresponding to the given instance. If a corresponding ``UOWTaskElement`` already - exists within this ``UOWTask``, its state is updated with the given - keyword arguments as appropriate. - - 'isdelete' when True indicates the operation will be a "delete" - operation (i.e. DELETE), otherwise is a "save" operation (i.e. INSERT/UPDATE). - a ``UOWTaskElement`` marked as "save" which receives the "isdelete" flag will - be marked as deleted, but the reverse operation does not apply (i.e. goes from - "delete" to being "not delete"). - - `listonly` indicates that the object does not require a delete - or save operation, but does require dependency operations to be - executed. For example, adding a child object to a parent via a - one-to-many relationship requires that a ``OneToManyDP`` object - corresponding to the parent's mapper synchronize the instance's primary key - value into the foreign key attribute of the child object, even though - no changes need be persisted on the parent. - - a listonly object may be "upgraded" to require a save/delete operation - by a subsequent append() of the same object instance with the `listonly` - flag set to False. once the flag is set to false, it stays that way - on the ``UOWTaskElement``. + return not self._objects and not self.dependencies + + def append(self, state, listonly=False, isdelete=False): + if state not in self._objects: + self._objects[state] = rec = UOWTaskElement(state) + else: + rec = self._objects[state] - `childtask` is an optional ``UOWTask`` element represending operations which - are dependent on the parent ``UOWTaskElement``. This flag is only used on - `UOWTask` objects created within the "cyclical sort" part of the hierarchical - sort, which generates a dependency tree of individual instances instead of - mappers when cycles between mappers are detected. - """ + rec.update(listonly, isdelete) + + def _append_cyclical_childtask(self, task): + if "cyclical" not in self._objects: + self._objects["cyclical"] = UOWTaskElement(None) + self._objects["cyclical"].childtasks.append(task) - try: - rec = self._objects[obj] - retval = False - except KeyError: - rec = UOWTaskElement(obj) - self._objects[obj] = rec - retval = True - if not listonly: - rec.listonly = False - if childtask: - rec.childtasks.append(childtask) - if isdelete: - rec.isdelete = True - return retval - - def append_postupdate(self, obj, post_update_cols): + def append_postupdate(self, state, post_update_cols): """issue a 'post update' UPDATE statement via this object's mapper immediately. 
this operation is used only with relations that specify the `post_update=True` @@ -599,31 +579,22 @@ class UOWTask(object): # postupdates are UPDATED immeditely (for now) # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns # instead of __eq__ - self.mapper.save_obj([obj], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols)) - return True - - def delete(self, obj): - """remove the given object from this UOWTask, if present.""" - - try: - del self._objects[obj] - except KeyError: - pass + self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols)) - def __contains__(self, obj): + def __contains__(self, state): """return True if the given object is contained within this UOWTask or inheriting tasks.""" for task in self.polymorphic_tasks(): - if obj in task._objects: + if state in task._objects: return True else: return False - def is_deleted(self, obj): + def is_deleted(self, state): """return True if the given object is marked as to be deleted within this UOWTask.""" try: - return self._objects[obj].isdelete + return self._objects[state].isdelete except KeyError: return False @@ -647,20 +618,14 @@ class UOWTask(object): polymorphic_todelete_elements = property(lambda self:[rec for rec in self.polymorphic_elements if rec.isdelete]) - polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.polymorphic_elements - if rec.obj is not None and not rec.listonly and rec.isdelete is False]) + polymorphic_tosave_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements + if rec.state is not None and not rec.listonly and rec.isdelete is False]) - polymorphic_todelete_objects = property(lambda self:[rec.obj for rec in self.polymorphic_elements - if rec.obj is not None and not rec.listonly and rec.isdelete is True]) + polymorphic_todelete_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements + if rec.state is not None and not rec.listonly and rec.isdelete is True]) - dependencies = property(lambda self:self._dependencies) - - cyclical_dependencies = property(lambda self:self._cyclical_dependencies) - polymorphic_dependencies = _polymorphic_collection(lambda task:task.dependencies) - polymorphic_childtasks = _polymorphic_collection(lambda task:task.childtasks) - polymorphic_cyclical_dependencies = _polymorphic_collection(lambda task:task.cyclical_dependencies) def _sort_circular_dependencies(self, trans, cycles): @@ -675,29 +640,19 @@ class UOWTask(object): """ allobjects = [] for task in cycles: - allobjects += [e.obj for e in task.polymorphic_elements] + allobjects += [e.state for e in task.polymorphic_elements] tuples = [] cycles = util.Set(cycles) - #print "BEGIN CIRC SORT-------" - #print "PRE-CIRC:" - #print list(cycles) #[0].dump() - - # dependency processors that arent part of the cyclical thing - # get put here extradeplist = [] - - # organizes a set of new UOWTasks that will be assembled into - # the final tree, for the purposes of holding new UOWDependencyProcessors - # which process small sub-sections of dependent parent/child operations dependencies = {} - def get_dependency_task(obj, depprocessor): + def get_dependency_task(state, depprocessor): try: - dp = dependencies[obj] + dp = dependencies[state] except KeyError: - dp = dependencies.setdefault(obj, {}) + dp = dependencies.setdefault(state, {}) try: l = dp[depprocessor] except KeyError: @@ -706,8 +661,8 @@ class UOWTask(object): return l def 
dependency_in_cycles(dep): - proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper(), True) - targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper(), True) + proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True) + targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True) return targettask in cycles and (proctask is not None and proctask in cycles) # organize all original UOWDependencyProcessors by their target task @@ -725,23 +680,25 @@ class UOWTask(object): for task in cycles: for subtask in task.polymorphic_tasks(): for taskelement in subtask.elements: - obj = taskelement.obj - object_to_original_task[obj] = subtask + state = taskelement.state + object_to_original_task[state] = subtask for dep in deps_by_targettask.get(subtask, []): # is this dependency involved in one of the cycles ? - if not dependency_in_cycles(dep): + # (don't count the DetectKeySwitch prop) + if dep.processor.no_dependencies or not dependency_in_cycles(dep): continue (processor, targettask) = (dep.processor, dep.targettask) isdelete = taskelement.isdelete # list of dependent objects from this object - childlist = dep.get_object_dependencies(obj, trans, passive=True) - if childlist is None: + (added, unchanged, deleted) = dep.get_object_dependencies(state, trans, passive=True) + if not added and not unchanged and not deleted: continue + # the task corresponding to saving/deleting of those dependent objects childtask = trans.get_task_by_mapper(processor.mapper) - childlist = childlist.added_items() + childlist.unchanged_items() + childlist.deleted_items() + childlist = added + unchanged + deleted for o in childlist: # other object is None. this can occur if the relationship is many-to-one @@ -758,43 +715,40 @@ class UOWTask(object): object_to_original_task[o] = childtask # create a tuple representing the "parent/child" - whosdep = dep.whose_dependent_on_who(obj, o) + whosdep = dep.whose_dependent_on_who(state, o) if whosdep is not None: # append the tuple to the partial ordering. tuples.append(whosdep) # create a UOWDependencyProcessor representing this pair of objects. 
# append it to a UOWTask - if whosdep[0] is obj: + if whosdep[0] is state: get_dependency_task(whosdep[0], dep).append(whosdep[0], isdelete=isdelete) else: get_dependency_task(whosdep[0], dep).append(whosdep[1], isdelete=isdelete) else: - get_dependency_task(obj, dep).append(obj, isdelete=isdelete) + # TODO: no test coverage here + get_dependency_task(state, dep).append(state, isdelete=isdelete) - #print "TUPLES", tuples - #print "ALLOBJECTS", allobjects - head = topological.QueueDependencySorter(tuples, allobjects).sort() - - # create a tree of UOWTasks corresponding to the tree of object instances - # created by the DependencySorter + head = topological.sort_as_tree(tuples, allobjects) used_tasks = util.Set() def make_task_tree(node, parenttask, nexttasks): - originating_task = object_to_original_task[node.item] + (state, cycles, children) = node + originating_task = object_to_original_task[state] used_tasks.add(originating_task) t = nexttasks.get(originating_task, None) if t is None: t = UOWTask(self.uowtransaction, originating_task.mapper) nexttasks[originating_task] = t - parenttask.append(None, listonly=False, isdelete=originating_task._objects[node.item].isdelete, childtask=t) - t.append(node.item, originating_task._objects[node.item].listonly, isdelete=originating_task._objects[node.item].isdelete) + parenttask._append_cyclical_childtask(t) + t.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete) - if dependencies.has_key(node.item): - for depprocessor, deptask in dependencies[node.item].iteritems(): + if state in dependencies: + for depprocessor, deptask in dependencies[state].iteritems(): t.cyclical_dependencies.add(depprocessor.branch(deptask)) nd = {} - for n in node.children: + for n in children: t2 = make_task_tree(n, t, nd) return t @@ -802,45 +756,30 @@ class UOWTask(object): # stick the non-circular dependencies onto the new UOWTask for d in extradeplist: - t._dependencies.add(d) + t.dependencies.add(d) - # if we have a head from the dependency sort, assemble child nodes - # onto the tree. note this only occurs if there were actual objects - # to be saved/deleted. if head is not None: make_task_tree(head, t, {}) + ret = [t] + + # add tasks that were in the cycle, but didnt get assembled + # into the cyclical tree, to the start of the list for t2 in cycles: - # tasks that were in the cycle but did not get assembled - # into the tree, add them as child tasks. these tasks - # will have no "save" or "delete" members, but may have dependency - # processors that operate upon other tasks outside of the cycle. if t2 not in used_tasks and t2 is not self: - # the task must be copied into a "cyclical" task, so that polymorphic - # rules dont fire off. 
this ensures that the task will have no "save" - # or "delete" members due to inheriting mappers which contain tasks localtask = UOWTask(self.uowtransaction, t2.mapper) - for obj in t2.elements: - localtask.append(obj, t2.listonly, isdelete=t2._objects[obj].isdelete) + for state in t2.elements: + localtask.append(state, t2.listonly, isdelete=t2._objects[state].isdelete) for dep in t2.dependencies: - localtask._dependencies.add(dep) - t.childtasks.insert(0, localtask) - - return t - - def dump(self): - """return a string representation of this UOWTask and its - full dependency graph.""" + localtask.dependencies.add(dep) + ret.insert(0, localtask) - buf = StringIO.StringIO() - import uowdumper - uowdumper.UOWDumper(self, buf) - return buf.getvalue() + return ret def __repr__(self): if self.mapper is not None: if self.mapper.__class__.__name__ == 'Mapper': - name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.name + name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.description else: name = repr(self.mapper) else: @@ -854,58 +793,36 @@ class UOWTaskElement(object): just part of the transaction as a placeholder for further dependencies (i.e. 'listonly'). - In the case of a ``UOWTaskElement`` present within an instance-level - graph formed due to cycles within the mapper-level graph, may also store a list of - childtasks, further UOWTasks containing objects dependent on this - element's object instance. + may also store additional sub-UOWTasks. """ - def __init__(self, obj): - self.obj = obj - self.__listonly = True + def __init__(self, state): + self.state = state + self.listonly = True self.childtasks = [] - self.__isdelete = False + self.isdelete = False self.__preprocessed = {} - def _get_listonly(self): - return self.__listonly - - def _set_listonly(self, value): - """Set_listonly is a one-way setter, will only go from True to False.""" - - if not value and self.__listonly: - self.__listonly = False - self.clear_preprocessed() - - def _get_isdelete(self): - return self.__isdelete - - def _set_isdelete(self, value): - if self.__isdelete is not value: - self.__isdelete = value - self.clear_preprocessed() - - listonly = property(_get_listonly, _set_listonly) - isdelete = property(_get_isdelete, _set_isdelete) + def update(self, listonly, isdelete): + if not listonly and self.listonly: + self.listonly = False + self.__preprocessed.clear() + if isdelete and not self.isdelete: + self.isdelete = True + self.__preprocessed.clear() def mark_preprocessed(self, processor): """Mark this element as *preprocessed* by a particular ``UOWDependencyProcessor``. - Preprocessing is the step which sweeps through all the - relationships on all the objects in the flush transaction and - adds other objects which are also affected. The actual logic is - part of ``UOWTransaction.execute()``. - - The preprocessing operations - are determined in part by the cascade rules indicated on a relationship, - and in part based on the normal semantics of relationships. - In some cases it can switch an object's state from *tosave* to *todelete*. - - Changes to the state of this ``UOWTaskElement`` will reset all - *preprocessed* flags, causing it to be preprocessed again. - When all ``UOWTaskElements have been fully preprocessed by all - UOWDependencyProcessors, then the topological sort can be - done. + Preprocessing is used by dependency.py to apply + flush-time cascade rules to relations and bring all + required objects into the flush context. 
+ + each processor as marked as "processed" when complete, however + changes to the state of this UOWTaskElement will reset + the list of completed processors, so that they + execute again, until no new objects or state changes + are brought in. """ self.__preprocessed[processor] = True @@ -913,11 +830,8 @@ class UOWTaskElement(object): def is_preprocessed(self, processor): return self.__preprocessed.get(processor, False) - def clear_preprocessed(self): - self.__preprocessed.clear() - def __repr__(self): - return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.obj.__class__.__name__, id(self.obj), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) ) + return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.state.class_.__name__, id(self.state.obj()), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) ) class UOWDependencyProcessor(object): """In between the saving and deleting of objects, process @@ -960,16 +874,16 @@ class UOWDependencyProcessor(object): def getobj(elem): elem.mark_preprocessed(self) - return elem.obj + return elem.state ret = False - elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)] - if len(elements): + elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None and not elem.is_preprocessed(self)] + if elements: ret = True self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) - elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None and not elem.is_preprocessed(self)] - if len(elements): + elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None and not elem.is_preprocessed(self)] + if elements: ret = True self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) return ret @@ -978,14 +892,14 @@ class UOWDependencyProcessor(object): """process all objects contained within this ``UOWDependencyProcessor``s target task.""" if not delete: - self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None], trans, delete=False) + self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None], trans, delete=False) else: - self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None], trans, delete=True) + self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None], trans, delete=True) - def get_object_dependencies(self, obj, trans, passive): - return self.processor.get_object_dependencies(obj, trans, passive=passive) + def get_object_dependencies(self, state, trans, passive): + return trans.get_attribute_history(state, self.processor.key, passive=passive) - def whose_dependent_on_who(self, obj, o): + def whose_dependent_on_who(self, state1, state2): """establish which object is operationally dependent amongst a parent/child using the semantics stated by the dependency processor. 
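The "whosdep" tuples collected during the cyclical sort express per-row ordering: whichever object the dependency processor reports as the parent must be saved before (and deleted after) its dependent. A toy stand-in for the topological ordering those tuples feed into -- assuming nothing about SQLAlchemy's internals, with hypothetical names throughout:

    from collections import defaultdict, deque

    def toposort(edges, nodes):
        """edges: iterable of (before, after) pairs; returns nodes ordered so
        that every 'before' precedes its 'after'."""
        indegree = dict((n, 0) for n in nodes)
        children = defaultdict(list)
        for before, after in edges:
            children[before].append(after)
            indegree[after] += 1
        queue = deque(n for n, d in indegree.items() if d == 0)
        ordered = []
        while queue:
            node = queue.popleft()
            ordered.append(node)
            for child in children[node]:
                indegree[child] -= 1
                if indegree[child] == 0:
                    queue.append(child)
        if len(ordered) != len(nodes):
            raise ValueError("dependency cycle detected")
        return ordered

    # e.g. toposort([("user", "address")], ["address", "user"])
    # -> ['user', 'address']  (parent row is flushed before its dependent)
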
@@ -994,7 +908,7 @@ class UOWDependencyProcessor(object): """ - return self.processor.whose_dependent_on_who(obj, o) + return self.processor.whose_dependent_on_who(state1, state2) def branch(self, task): """create a copy of this ``UOWDependencyProcessor`` against a new ``UOWTask`` object. @@ -1010,17 +924,19 @@ class UOWDependencyProcessor(object): class UOWExecutor(object): """Encapsulates the execution traversal of a UOWTransaction structure.""" - def execute(self, trans, task, isdelete=None): + def execute(self, trans, tasks, isdelete=None): if isdelete is not True: - self.execute_save_steps(trans, task) + for task in tasks: + self.execute_save_steps(trans, task) if isdelete is not False: - self.execute_delete_steps(trans, task) + for task in util.reversed(tasks): + self.execute_delete_steps(trans, task) def save_objects(self, trans, task): - task.mapper.save_obj(task.polymorphic_tosave_objects, trans) + task.mapper._save_obj(task.polymorphic_tosave_objects, trans) def delete_objects(self, trans, task): - task.mapper.delete_obj(task.polymorphic_todelete_objects, trans) + task.mapper._delete_obj(task.polymorphic_todelete_objects, trans) def execute_dependency(self, trans, dep, isdelete): dep.execute(trans, isdelete) @@ -1031,11 +947,9 @@ class UOWExecutor(object): self.execute_per_element_childtasks(trans, task, False) self.execute_dependencies(trans, task, False) self.execute_dependencies(trans, task, True) - self.execute_childtasks(trans, task, False) - + def execute_delete_steps(self, trans, task): self.execute_cyclical_dependencies(trans, task, True) - self.execute_childtasks(trans, task, True) self.execute_per_element_childtasks(trans, task, True) self.delete_objects(trans, task) @@ -1047,10 +961,6 @@ class UOWExecutor(object): for dep in util.reversed(list(task.polymorphic_dependencies)): self.execute_dependency(trans, dep, True) - def execute_childtasks(self, trans, task, isdelete=None): - for child in task.polymorphic_childtasks: - self.execute(trans, child, isdelete) - def execute_cyclical_dependencies(self, trans, task, isdelete): for dep in task.polymorphic_cyclical_dependencies: self.execute_dependency(trans, dep, isdelete) @@ -1061,8 +971,5 @@ class UOWExecutor(object): def execute_element_childtasks(self, trans, element, isdelete): for child in element.childtasks: - self.execute(trans, child, isdelete) + self.execute(trans, [child], isdelete) -# the AttributeManager used by the UOW/Session system to instrument -# object instances and track history. 
-attribute_manager = UOWAttributeManager() diff --git a/lib/sqlalchemy/orm/uowdumper.py b/lib/sqlalchemy/orm/uowdumper.py index 22b0ec8283..4b3fed70aa 100644 --- a/lib/sqlalchemy/orm/uowdumper.py +++ b/lib/sqlalchemy/orm/uowdumper.py @@ -1,5 +1,5 @@ # orm/uowdumper.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -8,58 +8,53 @@ from sqlalchemy.orm import unitofwork from sqlalchemy.orm import util as mapperutil +from sqlalchemy import util class UOWDumper(unitofwork.UOWExecutor): - def __init__(self, task, buf, verbose=False): + def __init__(self, tasks, buf, verbose=False): self.verbose = verbose self.indent = 0 - self.task = task + self.tasks = tasks self.buf = buf - self.starttask = task self.headers = {} - self.execute(None, task) + self.execute(None, tasks) - def execute(self, trans, task, isdelete=None): - oldstarttask = self.starttask - oldheaders = self.headers - self.starttask = task - self.headers = {} + def execute(self, trans, tasks, isdelete=None): + if isdelete is not True: + for task in tasks: + self._execute(trans, task, False) + if isdelete is not False: + for task in util.reversed(tasks): + self._execute(trans, task, True) + + def _execute(self, trans, task, isdelete): try: i = self._indent() - if len(i): - i += "-" - #i = i[0:-1] + "-" - self.buf.write(self._indent() + "\n") + if i: + i = i[:-1] + "+-" self.buf.write(i + " " + self._repr_task(task)) self.buf.write(" (" + (isdelete and "delete " or "save/update ") + "phase) \n") self.indent += 1 - super(UOWDumper, self).execute(trans, task, isdelete) + super(UOWDumper, self).execute(trans, [task], isdelete) finally: self.indent -= 1 - if self.starttask.is_empty(): - self.buf.write(self._indent() + " |- (empty task)\n") - else: - self.buf.write(self._indent() + " |----\n") - self.buf.write(self._indent() + "\n") - self.starttask = oldstarttask - self.headers = oldheaders def save_objects(self, trans, task): # sort elements to be inserted by insert order def comparator(a, b): - if a.obj is None: + if a.state is None: x = None - elif not hasattr(a.obj, '_sa_insert_order'): + elif not hasattr(a.state, 'insert_order'): x = None else: - x = a.obj._sa_insert_order - if b.obj is None: + x = a.state.insert_order + if b.state is None: y = None - elif not hasattr(b.obj, '_sa_insert_order'): + elif not hasattr(b.state, 'insert_order'): y = None else: - y = b.obj._sa_insert_order + y = b.state.insert_order return cmp(x, y) l = list(task.polymorphic_tosave_elements) @@ -68,7 +63,7 @@ class UOWDumper(unitofwork.UOWExecutor): if rec.listonly: continue self.header("Save elements"+ self._inheritance_tag(task)) - self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n") + self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n") self.closeheader() def delete_objects(self, trans, task): @@ -82,10 +77,8 @@ class UOWDumper(unitofwork.UOWExecutor): def _inheritance_tag(self, task): if not self.verbose: return "" - elif task is not self.starttask: - return (" (inheriting task %s)" % self._repr_task(task)) else: - return "" + return (" (inheriting task %s)" % self._repr_task(task)) def header(self, text): """Write a given header just once.""" @@ -115,11 +108,6 @@ class UOWDumper(unitofwork.UOWExecutor): def execute_dependencies(self, trans, task, isdelete=None): 
super(UOWDumper, self).execute_dependencies(trans, task, isdelete) - def execute_childtasks(self, trans, task, isdelete=None): - self.header("Child tasks" + self._inheritance_tag(task)) - super(UOWDumper, self).execute_childtasks(trans, task, isdelete) - self.closeheader() - def execute_cyclical_dependencies(self, trans, task, isdelete): self.header("Cyclical %s dependencies" % (isdelete and "delete" or "save")) super(UOWDumper, self).execute_cyclical_dependencies(trans, task, isdelete) @@ -140,14 +128,14 @@ class UOWDumper(unitofwork.UOWExecutor): val = proc.targettask.polymorphic_tosave_elements if self.verbose: - self.buf.write(self._indent() + " |- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % ( + self.buf.write(self._indent() + " +- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % ( repr(proc.processor.key), ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")), hex(id(proc)), self._repr_task(proc.targettask)) ) elif False: - self.buf.write(self._indent() + " |- %s attribute on %s\n" % ( + self.buf.write(self._indent() + " +- %s attribute on %s\n" % ( repr(proc.processor.key), ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")), ) @@ -155,18 +143,18 @@ class UOWDumper(unitofwork.UOWExecutor): if len(val) == 0: if self.verbose: - self.buf.write(self._indent() + " |- " + "(no objects)\n") + self.buf.write(self._indent() + " +- " + "(no objects)\n") for v in val: - self.buf.write(self._indent() + " |- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n") + self.buf.write(self._indent() + " +- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n") def _repr_task_element(self, te, attribute=None, process=False): - if te.obj is None: + if getattr(te, 'state', None) is None: objid = "(placeholder)" else: if attribute is not None: - objid = "%s.%s" % (mapperutil.instance_str(te.obj), attribute) + objid = "%s.%s" % (mapperutil.state_str(te.state), attribute) else: - objid = mapperutil.instance_str(te.obj) + objid = mapperutil.state_str(te.state) if self.verbose: return "%s (UOWTaskElement(%s, %s))" % (objid, hex(id(te)), (te.listonly and 'listonly' or (te.isdelete and 'delete' or 'save'))) elif process: @@ -177,12 +165,16 @@ class UOWDumper(unitofwork.UOWExecutor): def _repr_task(self, task): if task.mapper is not None: if task.mapper.__class__.__name__ == 'Mapper': - name = task.mapper.class_.__name__ + "/" + task.mapper.local_table.name + "/" + str(task.mapper.entity_name) + name = task.mapper.class_.__name__ + "/" + task.mapper.local_table.description + "/" + str(task.mapper.entity_name) else: name = repr(task.mapper) else: name = '(none)' - return ("UOWTask(%s, %s)" % (hex(id(task)), name)) + sd = getattr(task, '_superduper', False) + if sd: + return ("SD UOWTask(%s, %s)" % (hex(id(task)), name)) + else: + return ("UOWTask(%s, %s)" % (hex(id(task)), name)) def _repr_task_class(self, task): if task.mapper is not None and task.mapper.__class__.__name__ == 'Mapper': diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index d248c0dd01..19e5e59b93 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1,11 +1,14 @@ # mapper/util.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: 
http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, sql_util -from sqlalchemy.orm.interfaces import MapperExtension, EXT_PASS +from sqlalchemy import sql, util, exceptions +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql.util import row_adapter as create_row_adapter +from sqlalchemy.sql import visitors +from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator all_cascades = util.Set(["delete", "delete-orphan", "all", "merge", "expunge", "save-update", "refresh-expire", "none"]) @@ -75,45 +78,23 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'): result.append(sql.select([col(name, table) for name in colnames], from_obj=[table])) return sql.union_all(*result).alias(aliasname) -class TranslatingDict(dict): - """A dictionary that stores ``ColumnElement`` objects as keys. - Incoming ``ColumnElement`` keys are translated against those of an - underling ``FromClause`` for all operations. This way the columns - from any ``Selectable`` that is derived from or underlying this - ``TranslatingDict`` 's selectable can be used as keys. +class ExtensionCarrier(object): + """stores a collection of MapperExtension objects. + + allows an extension methods to be called on contained MapperExtensions + in the order they were added to this object. Also includes a 'methods' dictionary + accessor which allows for a quick check if a particular method + is overridden on any contained MapperExtensions. """ - - def __init__(self, selectable): - super(TranslatingDict, self).__init__() - self.selectable = selectable - - def __translate_col(self, col): - ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False) - if ourcol is None: - return col - else: - return ourcol - - def __getitem__(self, col): - return super(TranslatingDict, self).__getitem__(self.__translate_col(col)) - - def has_key(self, col): - return super(TranslatingDict, self).has_key(self.__translate_col(col)) - - def __setitem__(self, col, value): - return super(TranslatingDict, self).__setitem__(self.__translate_col(col), value) - - def __contains__(self, col): - return self.has_key(col) - - def setdefault(self, col, value): - return super(TranslatingDict, self).setdefault(self.__translate_col(col), value) - -class ExtensionCarrier(MapperExtension): + def __init__(self, _elements=None): - self.__elements = _elements or [] - + self.methods = {} + if _elements is not None: + self.__elements = [self.__inspect(e) for e in _elements] + else: + self.__elements = [] + def copy(self): return ExtensionCarrier(list(self.__elements)) @@ -123,186 +104,323 @@ class ExtensionCarrier(MapperExtension): def insert(self, extension): """Insert a MapperExtension at the beginning of this ExtensionCarrier's list.""" - self.__elements.insert(0, extension) + self.__elements.insert(0, self.__inspect(extension)) def append(self, extension): """Append a MapperExtension at the end of this ExtensionCarrier's list.""" - self.__elements.append(extension) + self.__elements.append(self.__inspect(extension)) - def _create_do(funcname): - def _do(self, *args, **kwargs): + def __inspect(self, extension): + for meth in MapperExtension.__dict__.keys(): + if meth not in self.methods and hasattr(extension, meth) and getattr(extension, meth) is not getattr(MapperExtension, meth): + self.methods[meth] = self.__create_do(meth) + return extension + + def __create_do(self, funcname): + def _do(*args, **kwargs): for elem in self.__elements: ret = getattr(elem, 
funcname)(*args, **kwargs) - if ret is not EXT_PASS: + if ret is not EXT_CONTINUE: return ret else: - return EXT_PASS - return _do + return EXT_CONTINUE - init_instance = _create_do('init_instance') - init_failed = _create_do('init_failed') - dispose_class = _create_do('dispose_class') - get_session = _create_do('get_session') - load = _create_do('load') - get = _create_do('get') - get_by = _create_do('get_by') - select_by = _create_do('select_by') - select = _create_do('select') - translate_row = _create_do('translate_row') - create_instance = _create_do('create_instance') - append_result = _create_do('append_result') - populate_instance = _create_do('populate_instance') - before_insert = _create_do('before_insert') - before_update = _create_do('before_update') - after_update = _create_do('after_update') - after_insert = _create_do('after_insert') - before_delete = _create_do('before_delete') - after_delete = _create_do('after_delete') - -class BinaryVisitor(sql.ClauseVisitor): - def __init__(self, func): - self.func = func - - def visit_binary(self, binary): - self.func(binary) + try: + _do.__name__ = funcname + except: + # cant set __name__ in py 2.3 + pass + return _do + + def _pass(self, *args, **kwargs): + return EXT_CONTINUE + + def __getattr__(self, key): + return self.methods.get(key, self._pass) class AliasedClauses(object): - """Creates aliases of a mapped tables for usage in ORM queries. - """ + """Creates aliases of a mapped tables for usage in ORM queries, and provides expression adaptation.""" - def __init__(self, mapped_table, alias=None): - if alias: - self.alias = alias - else: - self.alias = mapped_table.alias() - self.mapped_table = mapped_table - self.extra_cols = {} + def __init__(self, alias, equivalents=None, chain_to=None, should_adapt=True): + self.alias = alias + self.equivalents = equivalents self.row_decorator = self._create_row_adapter() - - def aliased_column(self, column): - """return the aliased version of the given column, creating a new label for it if not already - present in this AliasedClauses.""" + self.should_adapt = should_adapt + if should_adapt: + self.adapter = sql_util.ClauseAdapter(self.alias, equivalents=equivalents) + else: + self.adapter = visitors.NullVisitor() - conv = self.alias.corresponding_column(column, raiseerr=False) + if chain_to: + self.adapter.chain(chain_to.adapter) + + def aliased_column(self, column): + if not self.should_adapt: + return column + + conv = self.alias.corresponding_column(column) if conv: return conv + + # process column-level subqueries + aliased_column = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).traverse(column, clone=True) + + # anonymize labels which might have specific names + if isinstance(aliased_column, expression._Label): + aliased_column = aliased_column.label(None) - if column in self.extra_cols: - return self.extra_cols[column] - - aliased_column = column - # for column-level subqueries, swap out its selectable with our - # eager version as appropriate, and manually build the - # "correlation" list of the subquery. 
- class ModifySubquery(sql.ClauseVisitor): - def visit_select(s, select): - select._should_correlate = False - select.append_correlation(self.alias) - aliased_column = sql_util.ClauseAdapter(self.alias).chain(ModifySubquery()).traverse(aliased_column, clone=True) - aliased_column = aliased_column.label(None) - self.row_decorator.map[column] = aliased_column - # TODO: this is a little hacky - for attr in ('name', '_label'): - if hasattr(column, attr): - self.row_decorator.map[getattr(column, attr)] = aliased_column - self.extra_cols[column] = aliased_column + # add to row decorator explicitly + self.row_decorator({}).map[column] = aliased_column return aliased_column def adapt_clause(self, clause): - return self.aliased_column(clause) -# return sql_util.ClauseAdapter(self.alias).traverse(clause, clone=True) + return self.adapter.traverse(clause, clone=True) + def adapt_list(self, clauses): + return self.adapter.copy_and_process(clauses) + def _create_row_adapter(self): - """Return a callable which, - when passed a RowProxy, will return a new dict-like object - that translates Column objects to that of this object's Alias before calling upon the row. - - This allows a regular Table to be used to target columns in a row that was in reality generated from an alias - of that table, in such a way that the row can be passed to logic which knows nothing about the aliased form - of the table. - """ - class AliasedRowAdapter(object): - def __init__(self, row): - self.row = row - def __contains__(self, key): - return key in map or key in self.row - def has_key(self, key): - return key in self - def __getitem__(self, key): - if key in map: - key = map[key] - return self.row[key] - def keys(self): - return map.keys() - map = {} - for c in self.alias.c: - parent = self.mapped_table.corresponding_column(c) - map[parent] = c - map[parent._label] = c - map[parent.name] = c - for c in self.extra_cols: - map[c] = self.extra_cols[c] - # TODO: this is a little hacky - for attr in ('name', '_label'): - if hasattr(c, attr): - map[getattr(c, attr)] = self.extra_cols[c] - - AliasedRowAdapter.map = map - return AliasedRowAdapter + return create_row_adapter(self.alias, equivalent_columns=self.equivalents) + - class PropertyAliasedClauses(AliasedClauses): """extends AliasedClauses to add support for primary/secondary joins on a relation().""" - def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None): - super(PropertyAliasedClauses, self).__init__(prop.select_table) - + def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None, alias=None, should_adapt=True): + self.prop = prop + self.mapper = self.prop.mapper + self.table = self.prop.table self.parentclauses = parentclauses - if parentclauses is not None: - self.path = parentclauses.path + (prop.parent, prop.key) - else: - self.path = (prop.parent, prop.key) - self.prop = prop + if not alias: + from_obj = self.mapper._with_polymorphic_selectable() + alias = from_obj.alias() + + super(PropertyAliasedClauses, self).__init__(alias, equivalents=self.mapper._equivalent_columns, chain_to=parentclauses, should_adapt=should_adapt) if prop.secondary: self.secondary = prop.secondary.alias() + primary_aliasizer = sql_util.ClauseAdapter(self.secondary) + secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary)) + if parentclauses is not None: - aliasizer = sql_util.ClauseAdapter(self.alias).\ - chain(sql_util.ClauseAdapter(self.secondary)).\ - 
chain(sql_util.ClauseAdapter(parentclauses.alias)) - else: - aliasizer = sql_util.ClauseAdapter(self.alias).\ - chain(sql_util.ClauseAdapter(self.secondary)) - self.secondaryjoin = aliasizer.traverse(secondaryjoin, clone=True) - self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True) + primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, equivalents=parentclauses.equivalents)) + + self.secondaryjoin = secondary_aliasizer.traverse(secondaryjoin, clone=True) + self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True) else: + primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents) if parentclauses is not None: - aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side) - aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side)) - else: - aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side) - self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True) + primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side, equivalents=parentclauses.equivalents)) + + self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True) self.secondary = None self.secondaryjoin = None if prop.order_by: - self.order_by = sql_util.ClauseAdapter(self.alias).copy_and_process(util.to_list(prop.order_by)) + if prop.secondary: + # usually this is not used but occasionally someone has a sort key in their secondary + # table, even tho SA does not support writing this column directly + self.order_by = secondary_aliasizer.copy_and_process(util.to_list(prop.order_by)) + else: + self.order_by = primary_aliasizer.copy_and_process(util.to_list(prop.order_by)) + else: self.order_by = None + +class AliasedClass(object): + def __new__(cls, target): + from sqlalchemy.orm import attributes + mapper = _class_to_mapper(target) + alias = mapper.mapped_table.alias() + retcls = type(target.__name__ + "Alias", (cls,), {'alias':alias}) + retcls._class_state = mapper._class_state + for prop in mapper.iterate_properties: + existing = mapper._class_state.attrs[prop.key] + setattr(retcls, prop.key, attributes.InstrumentedAttribute(existing.impl, comparator=AliasedComparator(alias, existing.comparator))) + + return retcls + + def __init__(self, alias): + self.alias = alias + +class AliasedComparator(PropComparator): + def __init__(self, alias, comparator): + self.alias = alias + self.comparator = comparator + self.adapter = sql_util.ClauseAdapter(alias) + + def clause_element(self): + return self.adapter.traverse(self.comparator.clause_element(), clone=True) + + def operate(self, op, *other, **kwargs): + return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs), clone=True) + + def reverse_operate(self, op, other, **kwargs): + return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs), clone=True) + +from sqlalchemy.sql import expression +_selectable = expression._selectable +def _orm_selectable(selectable): + if _is_mapped_class(selectable): + if _is_aliased_class(selectable): + return selectable.alias + else: + return _class_to_mapper(selectable)._with_polymorphic_selectable() + else: + return _selectable(selectable) +expression._selectable = _orm_selectable + +class _ORMJoin(expression.Join): + """future functionality.""" + + __visit_name__ = expression.Join.__visit_name__ - mapper = property(lambda self:self.prop.mapper) - table = property(lambda self:self.prop.select_table) + def __init__(self, left, right, 
onclause=None, isouter=False): + if _is_mapped_class(left) or _is_mapped_class(right): + if hasattr(left, '_orm_mappers'): + left_mapper = left._orm_mappers[1] + adapt_from = left.right + else: + left_mapper = _class_to_mapper(left) + if _is_aliased_class(left): + adapt_from = left.alias + else: + adapt_from = None + + right_mapper = _class_to_mapper(right) + self._orm_mappers = (left_mapper, right_mapper) + + if isinstance(onclause, basestring): + prop = left_mapper.get_property(onclause) + + if _is_aliased_class(right): + adapt_to = right.alias + else: + adapt_to = None + + pj, sj, source, dest, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True) + + if sj: + left = sql.join(left, prop.secondary, onclause=pj) + onclause = sj + else: + onclause = pj + expression.Join.__init__(self, left, right, onclause, isouter) + + def join(self, right, onclause=None, isouter=False): + return _ORMJoin(self, right, onclause, isouter) + + def outerjoin(self, right, onclause=None): + return _ORMJoin(self, right, onclause, True) + +def _join(left, right, onclause=None): + """future functionality.""" + + return _ORMJoin(left, right, onclause, False) + +def _outerjoin(left, right, onclause=None): + """future functionality.""" + + return _ORMJoin(left, right, onclause, True) + +def has_identity(object): + return hasattr(object, '_instance_key') + +def _state_has_identity(state): + return '_instance_key' in state.dict + +def _is_mapped_class(cls): + return hasattr(cls, '_class_state') + +def _is_aliased_class(obj): + return isinstance(obj, type) and issubclass(obj, AliasedClass) - def __str__(self): - return "->".join([str(s) for s in self.path]) +def has_mapper(object): + """Return True if the given object has had a mapper association + set up, either through loading, or via insertion in a session. + """ + + return hasattr(object, '_entity_name') + +def _state_mapper(state, entity_name=None): + return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)] + +def object_mapper(object, entity_name=None, raiseerror=True): + """Given an object, return the primary Mapper associated with the object instance. + + object + The object instance. + entity_name + Entity name of the mapper to retrieve, if the given instance is + transient. Otherwise uses the entity name already associated + with the instance. + + raiseerror + Defaults to True: raise an ``InvalidRequestError`` if no mapper can + be located. If False, return None. + + """ + + try: + mapper = object.__class__._class_state.mappers[getattr(object, '_entity_name', entity_name)] + except (KeyError, AttributeError): + if raiseerror: + raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', entity_name))) + else: + return None + return mapper + +def class_mapper(class_, entity_name=None, compile=True, raiseerror=True): + """Given a class and optional entity_name, return the primary Mapper associated with the key. + + If no mapper can be located, raises ``InvalidRequestError``. 
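
A minimal sketch of the mapper-lookup helpers defined above, assuming the 0.4-style mapper() configuration; the table and class names are illustrative only. class_mapper() resolves the Mapper from a class, object_mapper() from an instance, and both raise InvalidRequestError (or return None when raiseerror=False) if no mapper is configured.

    from sqlalchemy import MetaData, Table, Column, Integer, String
    from sqlalchemy.orm import mapper, object_mapper, class_mapper

    metadata = MetaData()
    users = Table('users', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('name', String(50)))

    class User(object):
        pass

    mapper(User, users)

    m1 = class_mapper(User)        # Mapper located via the class
    m2 = object_mapper(User())     # same Mapper, located via an instance
    assert m1 is m2
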
+ """ + + try: + mapper = class_._class_state.mappers[entity_name] + except (KeyError, AttributeError): + if raiseerror: + raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name)) + else: + return None + if compile: + return mapper.compile() + else: + return mapper + +def _class_to_mapper(class_or_mapper, entity_name=None, compile=True): + if isinstance(class_or_mapper, type): + return class_mapper(class_or_mapper, entity_name=entity_name, compile=compile) + else: + if compile: + return class_or_mapper.compile() + else: + return class_or_mapper def instance_str(instance): """Return a string describing an instance.""" return instance.__class__.__name__ + "@" + hex(id(instance)) +def state_str(state): + """Return a string describing an instance.""" + if state is None: + return "None" + else: + return state.class_.__name__ + "@" + hex(id(state.obj())) + def attribute_str(instance, attribute): return instance_str(instance) + "." + attribute + +def identity_equal(a, b): + if a is b: + return True + id_a = getattr(a, '_instance_key', None) + id_b = getattr(b, '_instance_key', None) + if id_a is None or id_b is None: + return False + return id_a == id_b + diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index f86e14ab1e..31adf77d12 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -1,51 +1,45 @@ # pool.py - Connection pooling for SQLAlchemy -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Provide a connection pool implementation, which optionally manages -connections on a thread local basis. +"""Connection pooling for DB-API connections. -Also provides a DBAPI2 transparency layer so that pools can be managed -automatically, based on module type and connect arguments, simply by -calling regular DBAPI connect() methods. +Provides a number of connection pool implementations for a variety of +usage scenarios and thread behavior requirements imposed by the +application, DB-API or database itself. + +Also provides a DB-API 2.0 connection proxying mechanism allowing +regular DB-API connect() methods to be transparently managed by a +SQLAlchemy connection pool. """ import weakref, time -try: - import cPickle as pickle -except: - import pickle from sqlalchemy import exceptions, logging from sqlalchemy import queue as Queue - -try: - import thread, threading -except: - import dummy_thread as thread - import dummy_threading as threading +from sqlalchemy.util import thread, threading, pickle, as_interface proxies = {} def manage(module, **params): - """Return a proxy for module that automatically pools connections. + """Return a proxy for a DB-API module that automatically pools connections. - Given a DBAPI2 module and pool management parameters, returns a - proxy for the module that will automatically pool connections, + Given a DB-API 2.0 module and pool management parameters, returns + a proxy for the module that will automatically pool connections, creating new connection pools for each distinct set of connection arguments sent to the decorated module's connect() function. Arguments: module - A DBAPI2 database module. + A DB-API 2.0 database module. poolclass - The class used by the pool module to provide pooling. - Defaults to ``QueuePool``. 
+ The class used by the pool module to provide pooling. Defaults + to ``QueuePool``. See the ``Pool`` class for options. """ @@ -55,7 +49,7 @@ def manage(module, **params): return proxies.setdefault(module, _DBProxy(module, **params)) def clear_managers(): - """Remove all current DBAPI2 managers. + """Remove all current DB-API 2.0 managers. All pools and connections are disposed. """ @@ -65,10 +59,10 @@ def clear_managers(): proxies.clear() class Pool(object): - """Base Pool class. + """Base class for connection pools. - This is an abstract class, which is implemented by various - subclasses including: + This is an abstract class, implemented by various subclasses + including: QueuePool Pools multiple connections using ``Queue.Queue``. @@ -84,7 +78,7 @@ class Pool(object): is checked out at a time. The main argument, `creator`, is a callable function that returns - a newly connected DBAPI connection object. + a newly connected DB-API connection object. Options that are understood by Pool are: @@ -111,65 +105,76 @@ class Pool(object): surpassed the connection will be closed and replaced with a newly opened connection. Defaults to -1. - auto_close_cursors - Cursors, returned by ``connection.cursor()``, are tracked and - are automatically closed when the connection is returned to the - pool. Some DBAPIs like MySQLDB become unstable if cursors - remain open. Defaults to True. + listeners + A list of ``PoolListener``-like objects or dictionaries of callables + that receive events when DB-API connections are created, checked out and + checked in to the pool. - disallow_open_cursors - If `auto_close_cursors` is False, and `disallow_open_cursors` is - True, will raise an exception if an open cursor is detected upon - connection checkin. Defaults to False. + reset_on_return + Defaults to True. Reset the database state of connections returned to + the pool. This is typically a ROLLBACK to release locks and transaction + resources. Disable at your own peril. - If `auto_close_cursors` and `disallow_open_cursors` are both - False, then no cursor processing occurs upon checkin. """ - - def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, auto_close_cursors=True, - disallow_open_cursors=False): - self.logger = logging.instance_logger(self) + def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=True, + reset_on_return=True, listeners=None): + self.logger = logging.instance_logger(self, echoflag=echo) + # the WeakValueDictionary works more nicely than a regular dict of + # weakrefs. the latter can pile up dead reference objects which don't + # get cleaned out. WVD adds from 1-6 method calls to a checkout + # operation. self._threadconns = weakref.WeakValueDictionary() self._creator = creator self._recycle = recycle self._use_threadlocal = use_threadlocal - self.auto_close_cursors = auto_close_cursors - self.disallow_open_cursors = disallow_open_cursors + self._reset_on_return = reset_on_return self.echo = echo - echo = logging.echo_property() + self.listeners = [] + self._on_connect = [] + self._on_checkout = [] + self._on_checkin = [] + + if listeners: + for l in listeners: + self.add_listener(l) def unique_connection(self): return _ConnectionFairy(self).checkout() def create_connection(self): return _ConnectionRecord(self) - + def recreate(self): - """return a new instance of this Pool's class with identical creation arguments.""" + """Return a new instance with identical creation arguments.""" + raise NotImplementedError() def dispose(self): - """dispose of this pool. 
- - this method leaves the possibility of checked-out connections remaining opened, - so it is advised to not reuse the pool once dispose() is called, and to instead - use a new pool constructed by the recreate() method. + """Dispose of this pool. + + This method leaves the possibility of checked-out connections + remaining open, It is advised to not reuse the pool once dispose() + is called, and to instead use a new pool constructed by the + recreate() method. """ + raise NotImplementedError() - + def connect(self): if not self._use_threadlocal: return _ConnectionFairy(self).checkout() try: - return self._threadconns[thread.get_ident()].connfairy().checkout() + return self._threadconns[thread.get_ident()].checkout() except KeyError: - agent = _ConnectionFairy(self).checkout() - self._threadconns[thread.get_ident()] = agent._threadfairy - return agent + agent = _ConnectionFairy(self) + self._threadconns[thread.get_ident()] = agent + return agent.checkout() - def return_conn(self, agent): - self.do_return_conn(agent._connection_record) + def return_conn(self, record): + if self._use_threadlocal and thread.get_ident() in self._threadconns: + del self._threadconns[thread.get_ident()] + self.do_return_conn(record) def get(self): return self.do_get() @@ -183,6 +188,26 @@ class Pool(object): def status(self): raise NotImplementedError() + def add_listener(self, listener): + """Add a ``PoolListener``-like object to this pool. + + ``listener`` may be an object that implements some or all of + PoolListener, or a dictionary of callables containing implementations + of some or all of the named methods in PoolListener. + + """ + + listener = as_interface( + listener, methods=('connect', 'checkout', 'checkin')) + + self.listeners.append(listener) + if hasattr(listener, 'connect'): + self._on_connect.append(listener) + if hasattr(listener, 'checkout'): + self._on_checkout.append(listener) + if hasattr(listener, 'checkin'): + self._on_checkin.append(listener) + def log(self, msg): self.logger.info(msg) @@ -190,104 +215,154 @@ class _ConnectionRecord(object): def __init__(self, pool): self.__pool = pool self.connection = self.__connect() - self.properties = {} + self.info = {} + if pool._on_connect: + for l in pool._on_connect: + l.connect(self.connection, self) def close(self): if self.connection is not None: - self.__pool.log("Closing connection %s" % repr(self.connection)) - self.connection.close() + if self.__pool._should_log_info: + self.__pool.log("Closing connection %r" % self.connection) + try: + self.connection.close() + except (SystemExit, KeyboardInterrupt): + raise + except: + if self.__pool._should_log_info: + self.__pool.log("Exception closing connection %r" % + self.connection) def invalidate(self, e=None): - if e is not None: - self.__pool.log("Invalidate connection %s (reason: %s:%s)" % (repr(self.connection), e.__class__.__name__, str(e))) - else: - self.__pool.log("Invalidate connection %s" % repr(self.connection)) + if self.__pool._should_log_info: + if e is not None: + self.__pool.log("Invalidate connection %r (reason: %s:%s)" % + (self.connection, e.__class__.__name__, e)) + else: + self.__pool.log("Invalidate connection %r" % self.connection) self.__close() self.connection = None def get_connection(self): if self.connection is None: self.connection = self.__connect() - self.properties.clear() + self.info.clear() + if self.__pool._on_connect: + for l in self.__pool._on_connect: + l.connect(self.connection, self) elif (self.__pool._recycle > -1 and time.time() - self.starttime > 
self.__pool._recycle): - self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection)) + if self.__pool._should_log_info: + self.__pool.log("Connection %r exceeded timeout; recycling" % + self.connection) self.__close() self.connection = self.__connect() - self.properties.clear() + self.info.clear() + if self.__pool._on_connect: + for l in self.__pool._on_connect: + l.connect(self.connection, self) return self.connection def __close(self): try: - self.__pool.log("Closing connection %s" % (repr(self.connection))) + if self.__pool._should_log_info: + self.__pool.log("Closing connection %r" % self.connection) self.connection.close() except Exception, e: - self.__pool.log("Connection %s threw an error on close: %s" % (repr(self.connection), str(e))) + if self.__pool._should_log_info: + self.__pool.log("Connection %r threw an error on close: %s" % + (self.connection, e)) + if isinstance(e, (SystemExit, KeyboardInterrupt)): + raise def __connect(self): try: self.starttime = time.time() connection = self.__pool._creator() - self.__pool.log("Created new connection %s" % repr(connection)) + if self.__pool._should_log_info: + self.__pool.log("Created new connection %r" % connection) return connection except Exception, e: - self.__pool.log("Error on connect(): %s" % (str(e))) + if self.__pool._should_log_info: + self.__pool.log("Error on connect(): %s" % e) raise -class _ThreadFairy(object): - """Mark a thread identifier as owning a connection, for a thread local pool.""" + properties = property(lambda self: self.info, + doc="A synonym for .info, will be removed in 0.5.") - def __init__(self, connfairy): - self.connfairy = weakref.ref(connfairy) +def _finalize_fairy(connection, connection_record, pool, ref=None): + if ref is not None and connection_record.backref is not ref: + return + if connection is not None: + try: + if pool._reset_on_return: + connection.rollback() + # Immediately close detached instances + if connection_record is None: + connection.close() + except Exception, e: + if connection_record is not None: + connection_record.invalidate(e=e) + if isinstance(e, (SystemExit, KeyboardInterrupt)): + raise + if connection_record is not None: + connection_record.backref = None + if pool._should_log_info: + pool.log("Connection %r being returned to pool" % connection) + if pool._on_checkin: + for l in pool._on_checkin: + l.checkin(connection, connection_record) + pool.return_conn(connection_record) class _ConnectionFairy(object): - """Proxy a DBAPI connection object and provides return-on-dereference support.""" + """Proxies a DB-API connection and provides return-on-dereference support.""" def __init__(self, pool): - self._threadfairy = _ThreadFairy(self) - self._cursors = weakref.WeakKeyDictionary() self._pool = pool self.__counter = 0 try: - self._connection_record = pool.get() - self.connection = self._connection_record.get_connection() + rec = self._connection_record = pool.get() + conn = self.connection = self._connection_record.get_connection() + self._connection_record.backref = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref)) except: self.connection = None # helps with endless __getattr__ loops later on self._connection_record = None raise - if self._pool.echo: - self._pool.log("Connection %s checked out from pool" % repr(self.connection)) - + if self._pool._should_log_info: + self._pool.log("Connection %r checked out from pool" % + self.connection) + _logger = property(lambda self: self._pool.logger) - + is_valid = property(lambda 
self:self.connection is not None) - def _get_properties(self): - """A property collection unique to this DBAPI connection.""" - + def _get_info(self): + """An info collection unique to this DB-API connection.""" + try: - return self._connection_record.properties + return self._connection_record.info except AttributeError: if self.connection is None: raise exceptions.InvalidRequestError("This connection is closed") try: - return self._detatched_properties + return self._detached_info except AttributeError: - self._detatched_properties = value = {} + self._detached_info = value = {} return value - properties = property(_get_properties) - + info = property(_get_info) + properties = property(_get_info) + def invalidate(self, e=None): """Mark this connection as invalidated. - - The connection will be immediately closed. The - containing ConnectionRecord will create a new connection when next used. + + The connection will be immediately closed. The containing + ConnectionRecord will create a new connection when next used. """ + if self.connection is None: raise exceptions.InvalidRequestError("This connection is closed") if self._connection_record is not None: self._connection_record.invalidate(e=e) self.connection = None - self._cursors = None self._close() def cursor(self, *args, **kwargs): @@ -305,95 +380,98 @@ class _ConnectionFairy(object): if self.connection is None: raise exceptions.InvalidRequestError("This connection is closed") self.__counter +=1 - return self + + if not self._pool._on_checkout or self.__counter != 1: + return self + + # Pool listeners can trigger a reconnection on checkout + attempts = 2 + while attempts > 0: + try: + for l in self._pool._on_checkout: + l.checkout(self.connection, self._connection_record, self) + return self + except exceptions.DisconnectionError, e: + if self._pool._should_log_info: + self._pool.log( + "Disconnection detected on checkout: %s" % e) + self._connection_record.invalidate(e) + self.connection = self._connection_record.get_connection() + attempts -= 1 + + if self._pool._should_log_info: + self._pool.log("Reconnection attempts exhausted on checkout") + self.invalidate() + raise exceptions.InvalidRequestError("This connection is closed") def detach(self): - """Separate this Connection from its Pool. - - This means that the connection will no longer be returned to the - pool when closed, and will instead be literally closed. The - containing ConnectionRecord is separated from the DBAPI connection, and - will create a new connection when next used. + """Separate this connection from its Pool. + + This means that the connection will no longer be returned to the + pool when closed, and will instead be literally closed. The + containing ConnectionRecord is separated from the DB-API connection, + and will create a new connection when next used. + + Note that any overall connection limiting constraints imposed by a + Pool implementation may be violated after a detach, as the detached + connection is removed from the pool's knowledge and control. 
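
A hedged sketch of the checkout-listener hook used by the retry loop above: a listener's checkout() may raise DisconnectionError, in which case the pool invalidates the current DB-API connection and checks out a fresh one (two attempts, as implemented in _ConnectionFairy.checkout()). The sqlite3 module and the ping statement are illustrative stand-ins.

    import sqlite3
    from sqlalchemy import exceptions, pool

    class PingListener(object):
        def checkout(self, dbapi_con, con_record, con_proxy):
            try:
                dbapi_con.execute("select 1")
            except sqlite3.Error:
                # signals the pool that this connection is dead; it will be
                # invalidated and replaced before checkout() returns
                raise exceptions.DisconnectionError()

    p = pool.QueuePool(lambda: sqlite3.connect(':memory:'),
                       listeners=[PingListener()])
    conn = p.connect()    # PingListener.checkout() runs here
    conn.close()
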
""" - + if self._connection_record is not None: - self._connection_record.connection = None + self._connection_record.connection = None + self._connection_record.backref = None self._pool.do_return_conn(self._connection_record) - self._detatched_properties = \ - self._connection_record.properties.copy() + self._detached_info = \ + self._connection_record.info.copy() self._connection_record = None - def close_open_cursors(self): - if self._cursors is not None: - for c in list(self._cursors): - c.close() - def close(self): self.__counter -=1 if self.__counter == 0: self._close() - def __del__(self): - self._close() - def _close(self): - if self._cursors is not None: - # cursors should be closed before connection is returned to the pool. some dbapis like - # mysql have real issues if they are not. - if self._pool.auto_close_cursors: - self.close_open_cursors() - elif self._pool.disallow_open_cursors: - if len(self._cursors): - raise exceptions.InvalidRequestError("This connection still has %d open cursors" % len(self._cursors)) - if self.connection is not None: - try: - self.connection.rollback() - # Immediately close detached instances - if self._connection_record is None: - self.connection.close() - except Exception, e: - if self._connection_record is not None: - self._connection_record.invalidate(e=e) - if self._connection_record is not None: - if self._pool.echo: - self._pool.log("Connection %s being returned to pool" % repr(self.connection)) - self._pool.return_conn(self) + _finalize_fairy(self.connection, self._connection_record, self._pool) self.connection = None self._connection_record = None - self._threadfairy = None - self._cursors = None class _CursorFairy(object): def __init__(self, parent, cursor): self.__parent = parent - self.__parent._cursors[self] = True self.cursor = cursor def invalidate(self, e=None): self.__parent.invalidate(e=e) - + def close(self): - if self in self.__parent._cursors: - del self.__parent._cursors[self] + try: + self.cursor.close() + except Exception, e: try: - self.cursor.close() - except Exception, e: - self.__parent._logger.warn("Error closing cursor: " + str(e)) + ex_text = str(e) + except TypeError: + ex_text = repr(e) + self.__parent._logger.warn("Error closing cursor: " + ex_text) + + if isinstance(e, (SystemExit, KeyboardInterrupt)): + raise def __getattr__(self, key): return getattr(self.cursor, key) class SingletonThreadPool(Pool): - """Maintain one connection per each thread, never moving a - connection to a thread other than the one which it was created in. + """A Pool that maintains one connection per thread. - This is used for SQLite, which both does not handle multithreading - by default, and also requires a singleton connection if a :memory: - database is being used. + Maintains one connection per each thread, never moving a connection to a + thread other than the one which it was created in. + + This is used for SQLite, which both does not handle multithreading by + default, and also requires a singleton connection if a :memory: database + is being used. Options are the same as those of Pool, as well as: - pool_size : 5 + pool_size: 5 The number of threads in which to maintain connections at once. 
""" @@ -405,20 +483,25 @@ class SingletonThreadPool(Pool): def recreate(self): self.log("Pool recreating") - return SingletonThreadPool(self._creator, pool_size=self.size, recycle=self._recycle, echo=self.echo, use_threadlocal=self._use_threadlocal, auto_close_cursors=self.auto_close_cursors, disallow_open_cursors=self.disallow_open_cursors) - + return SingletonThreadPool(self._creator, pool_size=self.size, recycle=self._recycle, echo=self._should_log_info, use_threadlocal=self._use_threadlocal, listeners=self.listeners) + def dispose(self): - """dispose of this pool. - - this method leaves the possibility of checked-out connections remaining opened, - so it is advised to not reuse the pool once dispose() is called, and to instead - use a new pool constructed by the recreate() method. + """Dispose of this pool. + + this method leaves the possibility of checked-out connections + remaining opened, so it is advised to not reuse the pool once + dispose() is called, and to instead use a new pool constructed + by the recreate() method. """ + for key, conn in self._conns.items(): try: conn.close() + except (SystemExit, KeyboardInterrupt): + raise except: - # sqlite won't even let you close a conn from a thread that didn't create it + # sqlite won't even let you close a conn from a thread + # that didn't create it pass del self._conns[key] @@ -454,7 +537,7 @@ class SingletonThreadPool(Pool): return c class QueuePool(Pool): - """Use ``Queue.Queue`` to maintain a fixed-size list of connections. + """A Pool that imposes a limit on the number of open connections. Arguments include all those used by the base Pool class, as well as: @@ -494,7 +577,7 @@ class QueuePool(Pool): def recreate(self): self.log("Pool recreating") - return QueuePool(self._creator, pool_size=self._pool.maxsize, max_overflow=self._max_overflow, timeout=self._timeout, recycle=self._recycle, echo=self.echo, use_threadlocal=self._use_threadlocal, auto_close_cursors=self.auto_close_cursors, disallow_open_cursors=self.disallow_open_cursors) + return QueuePool(self._creator, pool_size=self._pool.maxsize, max_overflow=self._max_overflow, timeout=self._timeout, recycle=self._recycle, echo=self._should_log_info, use_threadlocal=self._use_threadlocal, listeners=self.listeners) def do_return_conn(self, conn): try: @@ -545,7 +628,8 @@ class QueuePool(Pool): break self._overflow = 0 - self.size() - self.log("Pool disposed. " + self.status()) + if self._should_log_info: + self.log("Pool disposed. " + self.status()) def status(self): tup = (self.size(), self.checkedin(), self.overflow(), self.checkedout()) @@ -564,10 +648,10 @@ class QueuePool(Pool): return self._pool.maxsize - self._pool.qsize() + self._overflow class NullPool(Pool): - """A Pool implementation which does not pool connections. + """A Pool which does not pool connections. - Instead it literally opens and closes the underlying DBAPI - connection per each connection open/close. + Instead it literally opens and closes the underlying DB-API connection + per each connection open/close. 
""" def status(self): @@ -583,8 +667,7 @@ class NullPool(Pool): return self.create_connection() class StaticPool(Pool): - """A Pool implementation which stores exactly one connection that is - returned for all requests.""" + """A Pool of exactly one connection, used for all requests.""" def __init__(self, creator, **params): Pool.__init__(self, creator, **params) @@ -594,6 +677,10 @@ class StaticPool(Pool): def status(self): return "StaticPool" + def dispose(self): + self._conn.close() + self._conn = None + def create_connection(self): return self._conn @@ -605,15 +692,14 @@ class StaticPool(Pool): def do_get(self): return self.connection - - + + class AssertionPool(Pool): - """A Pool implementation that allows at most one checked out - connection at a time. + """A Pool that allows at most one checked out connection at any given time. - This will raise an exception if more than one connection is - checked out at a time. Useful for debugging code that is using - more connections than desired. + This will raise an exception if more than one connection is checked out + at a time. Useful for debugging code that is using more connections + than desired. """ ## TODO: modify this to handle an arbitrary connection count. @@ -643,19 +729,21 @@ class AssertionPool(Pool): return c class _DBProxy(object): - """Proxy a DBAPI2 connect() call to a pooled connection keyed to - the specific connect parameters. Other attributes are proxied - through via __getattr__. + """Layers connection pooling behavior on top of a standard DB-API module. + + Proxies a DB-API 2.0 connect() call to a connection pool keyed to the + specific connect parameters. Other functions and attributes are delegated + to the underlying DB-API module. """ - def __init__(self, module, poolclass = QueuePool, **params): - """Initialize a new proxy. + def __init__(self, module, poolclass=QueuePool, **params): + """Initializes a new proxy. module - a DBAPI2 module. + a DB-API 2.0 module poolclass - a Pool class, defaulting to QueuePool. + a Pool class, defaulting to QueuePool Other parameters are sent to the Pool object's constructor. """ @@ -687,16 +775,14 @@ class _DBProxy(object): def connect(self, *args, **params): """Activate a connection to the database. - Connect to the database using this DBProxy's module and the - given connect arguments. If the arguments match an existing - pool, the connection will be returned from the pool's current - thread-local connection instance, or if there is no - thread-local connection instance it will be checked out from - the set of pooled connections. + Connect to the database using this DBProxy's module and the given + connect arguments. If the arguments match an existing pool, the + connection will be returned from the pool's current thread-local + connection instance, or if there is no thread-local connection + instance it will be checked out from the set of pooled connections. - If the pool has no available connections and allows new - connections to be created, a new database connection will be - made. + If the pool has no available connections and allows new connections + to be created, a new database connection will be made. 
""" return self.get_pool(*args, **params).connect() diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 3faa3b89c7..9a4bf4109c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -1,34 +1,47 @@ # schema.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """The schema module provides the building blocks for database metadata. -This means all the entities within a SQL database that we might want -to look at, modify, or create and delete are described by these -objects, in a database-agnostic way. +Each element within this module describes a database entity which can be +created and dropped, or is otherwise part of such an entity. Examples include +tables, columns, sequences, and indexes. -A structure of SchemaItems also provides a *visitor* interface which is -the primary method by which other methods operate upon the schema. -The SQL package extends this structure with its own clause-specific -objects as well as the visitor interface, so that the schema package -*plugs in* to the SQL package. -""" +All entities are subclasses of [sqlalchemy.schema#SchemaItem], and as defined +in this module they are intended to be agnostic of any vendor-specific +constructs. -from sqlalchemy import sql, types, exceptions,util, databases -import sqlalchemy -import re, string, inspect +A collection of entities are grouped into a unit called +[sqlalchemy.schema#MetaData]. MetaData serves as a logical grouping of schema +elements, and can also be associated with an actual database connection such +that operations involving the contained elements can contact the database as +needed. -__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint', - 'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'DefaultGenerator', 'Constraint', - 'MetaData', 'ThreadLocalMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] +Two of the elements here also build upon their "syntactic" counterparts, which +are defined in [sqlalchemy.sql.expression#], specifically +[sqlalchemy.schema#Table] and [sqlalchemy.schema#Column]. Since these objects +are part of the SQL expression language, they are usable as components in SQL +expressions. 
""" + +import re, inspect +from sqlalchemy import types, exceptions, util, databases +from sqlalchemy.sql import expression, visitors + +URL = None + +__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', + 'ForeignKeyConstraint', 'PrimaryKeyConstraint', 'CheckConstraint', + 'UniqueConstraint', 'DefaultGenerator', 'Constraint', 'MetaData', + 'ThreadLocalMetaData', 'SchemaVisitor', 'PassiveDefault', + 'ColumnDefault', 'DDL'] class SchemaItem(object): """Base class for items that define a database schema.""" - __metaclass__ = sql._FigureVisitName + __metaclass__ = expression._FigureVisitName def _init_items(self, *args): """Initialize the list of child items for this SchemaItem.""" @@ -37,303 +50,250 @@ class SchemaItem(object): if item is not None: item._set_parent(self) - def _get_parent(self): - raise NotImplementedError() - def _set_parent(self, parent): """Associate with this SchemaItem's parent object.""" raise NotImplementedError() - + def get_children(self, **kwargs): """used to allow SchemaVisitor access""" return [] - + def __repr__(self): return "%s()" % self.__class__.__name__ - def _derived_metadata(self): - """Return the the MetaData to which this item is bound.""" - - return None - - def _get_engine(self, raiseerr=False): - """Return the engine or None if no engine.""" - - if raiseerr: - m = self._derived_metadata() - e = m and m.bind or None - if e is None: - raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") - else: - return e - else: - m = self._derived_metadata() - return m and m.bind or None - - def _set_casing_strategy(self, kwargs, keyname='case_sensitive'): - """Set the "case_sensitive" argument sent via keywords to the item's constructor. - - For the purposes of Table's 'schema' property, the name of the - variable is optionally configurable. - """ - setattr(self, '_%s_setting' % keyname, kwargs.pop(keyname, None)) - - def _determine_case_sensitive(self, keyname='case_sensitive'): - """Determine the `case_sensitive` value for this item. + def bind(self): + """Return the connectable associated with this SchemaItem.""" - For the purposes of Table's `schema` property, the name of the - variable is optionally configurable. + m = self.metadata + return m and m.bind or None + bind = property(bind) - A local non-None value overrides all others. After that, the - parent item (i.e. ``Column`` for a ``Sequence``, ``Table`` for - a ``Column``, ``MetaData`` for a ``Table``) is searched for a - non-None setting, traversing each parent until none are found. - finally, case_sensitive is set to True as a default. - """ - - local = getattr(self, '_%s_setting' % keyname, None) - if local is not None: - return local - parent = self - while parent is not None: - parent = parent._get_parent() - if parent is not None: - parentval = getattr(parent, '_case_sensitive_setting', None) - if parentval is not None: - return parentval - return True - - def _get_case_sensitive(self): - """late-compile the 'case-sensitive' setting when first accessed. - - typically the SchemaItem will be assembled into its final structure - of other SchemaItems at this point, whereby it can attain this setting - from its containing SchemaItem if not defined locally. 
- """ - + def info(self): try: - return self.__case_sensitive + return self._info except AttributeError: - self.__case_sensitive = self._determine_case_sensitive() - return self.__case_sensitive - case_sensitive = property(_get_case_sensitive) + self._info = {} + return self._info + info = property(info) + - metadata = property(lambda s:s._derived_metadata()) - bind = property(lambda s:s._get_engine()) - def _get_table_key(name, schema): if schema is None: return name else: return schema + "." + name -class _TableSingleton(sql._FigureVisitName): +class _TableSingleton(expression._FigureVisitName): """A metaclass used by the ``Table`` object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): - schema = kwargs.get('schema', None) - autoload = kwargs.pop('autoload', False) - autoload_with = kwargs.pop('autoload_with', False) - mustexist = kwargs.pop('mustexist', False) + schema = kwargs.get('schema', kwargs.get('owner', None)) useexisting = kwargs.pop('useexisting', False) - include_columns = kwargs.pop('include_columns', None) + mustexist = kwargs.pop('mustexist', False) key = _get_table_key(name, schema) try: table = metadata.tables[key] - if len(args): - if not useexisting: - raise exceptions.ArgumentError("Table '%s' is already defined for this MetaData instance." % key) + if not useexisting and table._cant_override(*args, **kwargs): + raise exceptions.InvalidRequestError( + "Table '%s' is already defined for this MetaData instance. " + "Specify 'useexisting=True' to redefine options and " + "columns on an existing Table object." % key) + else: + table._init_existing(*args, **kwargs) return table except KeyError: if mustexist: - raise exceptions.ArgumentError("Table '%s.%s' not defined" % (schema, name)) - table = type.__call__(self, name, metadata, **kwargs) - table._set_parent(metadata) - # load column definitions from the database if 'autoload' is defined - # we do it after the table is in the singleton dictionary to support - # circular foreign keys - if autoload: - try: - if autoload_with: - autoload_with.reflecttable(table, include_columns=include_columns) - else: - metadata._get_engine(raiseerr=True).reflecttable(table, include_columns=include_columns) - except exceptions.NoSuchTableError: + raise exceptions.InvalidRequestError( + "Table '%s' not defined" % (key)) + try: + return type.__call__(self, name, metadata, *args, **kwargs) + except: + if key in metadata.tables: del metadata.tables[key] - raise - # initialize all the column, etc. objects. done after - # reflection to allow user-overrides - table._init_items(*args) - return table + raise -class Table(SchemaItem, sql.TableClause): - """Represent a relational database table. - - This subclasses ``sql.TableClause`` to provide a table that is - associated with an instance of ``MetaData``, which in turn - may be associated with an instance of ``Engine``. - - Whereas ``TableClause`` represents a table as its used in an SQL - expression, ``Table`` represents a table as it exists in a - database schema. - - If this ``Table`` is ultimately associated with an engine, - the ``Table`` gains the ability to access the database directly - without the need for dealing with an explicit ``Connection`` object; - this is known as "implicit execution". 
- - Implicit operation allows the ``Table`` to access the database to - reflect its own properties (via the autoload=True flag), it allows - the create() and drop() methods to be called without passing - a connectable, and it also propigates the underlying engine to - constructed SQL objects so that they too can be executed via their - execute() method without the need for a ``Connection``. - """ +class Table(SchemaItem, expression.TableClause): + """Represent a relational database table.""" __metaclass__ = _TableSingleton - def __init__(self, name, metadata, **kwargs): + ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop') + + def __init__(self, name, metadata, *args, **kwargs): """Construct a Table. - Table objects can be constructed directly. The init method is - actually called via the TableSingleton metaclass. Arguments - are: + Table objects can be constructed directly. Arguments are: name - The name of this table, exactly as it appears, or will - appear, in the database. + The name of this table, exactly as it appears, or will appear, in + the database. - This property, along with the *schema*, indicates the - *singleton identity* of this table. + This property, along with the *schema*, indicates the *singleton + identity* of this table. - Further tables constructed with the same name/schema - combination will return the same Table instance. + Further tables constructed with the same name/schema combination + will return the same Table instance. \*args Should contain a listing of the Column objects for this table. \**kwargs - options include: + kwargs include: schema - The *schema name* for this table, which is - required if the table resides in a schema other than the - default selected schema for the engine's database - connection. Defaults to ``None``. + The *schema name* for this table, which is required if the table + resides in a schema other than the default selected schema for the + engine's database connection. Defaults to ``None``. autoload - Defaults to False: the Columns for this table should be - reflected from the database. Usually there will be no - Column objects in the constructor if this property is set. + Defaults to False: the Columns for this table should be reflected + from the database. Usually there will be no Column objects in the + constructor if this property is set. autoload_with if autoload==True, this is an optional Engine or Connection - instance to be used for the table reflection. If ``None``, - the underlying MetaData's bound connectable will be used. - + instance to be used for the table reflection. If ``None``, the + underlying MetaData's bound connectable will be used. + include_columns - A list of strings indicating a subset of columns to be - loaded via the ``autoload`` operation; table columns who - aren't present in this list will not be represented on the resulting - ``Table`` object. Defaults to ``None`` which indicates all - columns should be reflected. - + A list of strings indicating a subset of columns to be loaded via + the ``autoload`` operation; table columns who aren't present in + this list will not be represented on the resulting ``Table`` + object. Defaults to ``None`` which indicates all columns should + be reflected. + + info + Defaults to {}: A space to store application specific data; this + must be a dictionary. + mustexist - Defaults to False: indicates that this Table must already - have been defined elsewhere in the application, else an - exception is raised. 
+ Defaults to False: indicates that this Table must already have + been defined elsewhere in the application, else an exception is + raised. useexisting - Defaults to False: indicates that if this Table was - already defined elsewhere in the application, disregard - the rest of the constructor arguments. + Defaults to False: indicates that if this Table was already + defined elsewhere in the application, disregard the rest of the + constructor arguments. owner - Defaults to None: optional owning user of this table. - useful for databases such as Oracle to aid in table - reflection. + Deprecated; this is an oracle-only argument - "schema" should + be used in its place. quote - Defaults to False: indicates that the Table identifier - must be properly escaped and quoted before being sent to - the database. This flag overrides all other quoting - behavior. + When True, indicates that the Table identifier must be quoted. + This flag does *not* disable quoting; for case-insensitive names, + use an all lower case identifier. quote_schema - Defaults to False: indicates that the Namespace identifier - must be properly escaped and quoted before being sent to - the database. This flag overrides all other quoting - behavior. - - case_sensitive - Defaults to True: indicates quoting should be used if the - identifier contains mixed case. - - case_sensitive_schema - Defaults to True: indicates quoting should be used if the - identifier contains mixed case. + When True, indicates that the schema identifier must be quoted. + This flag does *not* disable quoting; for case-insensitive names, + use an all lower case identifier. """ + super(Table, self).__init__(name) - self._metadata = metadata - self.schema = kwargs.pop('schema', None) + self.metadata = metadata + self.schema = kwargs.pop('schema', kwargs.pop('owner', None)) self.indexes = util.Set() self.constraints = util.Set() - self._columns = sql.ColumnCollection() + self._columns = expression.ColumnCollection() self.primary_key = PrimaryKeyConstraint() self._foreign_keys = util.OrderedSet() - self.quote = kwargs.pop('quote', False) - self.quote_schema = kwargs.pop('quote_schema', False) + self.ddl_listeners = util.defaultdict(list) + self.kwargs = {} if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) else: self.fullname = self.name - self.owner = kwargs.pop('owner', None) - self._set_casing_strategy(kwargs) - self._set_casing_strategy(kwargs, keyname='case_sensitive_schema') + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', None) + include_columns = kwargs.pop('include_columns', None) + + self._set_parent(metadata) + self.__extra_kwargs(**kwargs) + + # load column definitions from the database if 'autoload' is defined + # we do it after the table is in the singleton dictionary to support + # circular foreign keys + if autoload: + if autoload_with: + autoload_with.reflecttable(self, include_columns=include_columns) + else: + _bind_or_error(metadata).reflecttable(self, include_columns=include_columns) + + # initialize all the column, etc. objects. 
done after reflection to + # allow user-overrides + self.__post_init(*args, **kwargs) + + def _init_existing(self, *args, **kwargs): + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', None) + schema = kwargs.pop('schema', None) + if schema and schema != self.schema: + raise exceptions.ArgumentError( + "Can't change schema of existing table from '%s' to '%s'", + (self.schema, schema)) + + include_columns = kwargs.pop('include_columns', None) + if include_columns: + for c in self.c: + if c.name not in include_columns: + self.c.remove(c) + + self.__extra_kwargs(**kwargs) + self.__post_init(*args, **kwargs) + + def _cant_override(self, *args, **kwargs): + """Return True if any argument is not supported as an override. + + Takes arguments that would be sent to Table.__init__, and returns + True if any of them would be disallowed if sent to an existing + Table singleton. + """ + return bool(args) or bool(util.Set(kwargs).difference( + ['autoload', 'autoload_with', 'schema', 'owner'])) + + def __extra_kwargs(self, **kwargs): + self.quote = kwargs.pop('quote', False) + self.quote_schema = kwargs.pop('quote_schema', False) + if kwargs.get('info'): + self._info = kwargs.pop('info') + + # validate remaining kwargs that they all specify DB prefixes if len([k for k in kwargs if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]): raise TypeError("Invalid argument(s) for Table: %s" % repr(kwargs.keys())) + self.kwargs.update(kwargs) - # store extra kwargs, which should only contain db-specific options - self.kwargs = kwargs - key = property(lambda self:_get_table_key(self.name, self.schema)) - - def _export_columns(self, columns=None): - # override FromClause's collection initialization logic; TableClause and Table - # implement it differently - pass + def __post_init(self, *args, **kwargs): + self._init_items(*args) - def _get_case_sensitive_schema(self): - try: - return getattr(self, '_case_sensitive_schema') - except AttributeError: - setattr(self, '_case_sensitive_schema', self._determine_case_sensitive(keyname='case_sensitive_schema')) - return getattr(self, '_case_sensitive_schema') - case_sensitive_schema = property(_get_case_sensitive_schema) + def key(self): + return _get_table_key(self.name, self.schema) + key = property(key) def _set_primary_key(self, pk): if getattr(self, '_primary_key', None) in self.constraints: self.constraints.remove(self._primary_key) self._primary_key = pk self.constraints.add(pk) - primary_key = property(lambda s:s._primary_key, _set_primary_key) - def _derived_metadata(self): - return self._metadata + def primary_key(self): + return self._primary_key + primary_key = property(primary_key, _set_primary_key) def __repr__(self): - return "Table(%s)" % string.join( + return "Table(%s)" % ', '.join( [repr(self.name)] + [repr(self.metadata)] + [repr(x) for x in self.columns] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']] - , ',') + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]) def __str__(self): - return _get_table_key(self.encodedname, self.schema) + return _get_table_key(self.description, self.schema) def append_column(self, column): """Append a ``Column`` to this ``Table``.""" @@ -345,16 +305,45 @@ class Table(SchemaItem, sql.TableClause): constraint._set_parent(self) - def _get_parent(self): - return self._metadata + def append_ddl_listener(self, event, listener): + """Append a DDL event listener to this ``Table``. 
+ + The ``listener`` callable will be triggered when this ``Table`` is + created or dropped, either directly before or after the DDL is issued + to the database. The listener may modify the Table, but may not abort + the event itself. + + Arguments are: + + event + One of ``Table.ddl_events``; e.g. 'before-create', 'after-create', + 'before-drop' or 'after-drop'. + + listener + A callable, invoked with three positional arguments: + + event + The event currently being handled + schema_item + The ``Table`` object being created or dropped + bind + The ``Connection`` bueing used for DDL execution. + + Listeners are added to the Table's ``ddl_listeners`` attribute. + """ + + if event not in self.ddl_events: + raise LookupError(event) + self.ddl_listeners[event].append(listener) def _set_parent(self, metadata): metadata.tables[_get_table_key(self.name, self.schema)] = self - self._metadata = metadata + self.metadata = metadata def get_children(self, column_collections=True, schema_visitor=False, **kwargs): if not schema_visitor: - return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs) + return expression.TableClause.get_children( + self, column_collections=column_collections, **kwargs) else: if column_collections: return [c for c in self.columns] @@ -365,7 +354,7 @@ class Table(SchemaItem, sql.TableClause): """Return True if this table exists.""" if bind is None: - bind = self._get_engine(raiseerr=True) + bind = _bind_or_error(self) def do(conn): return conn.dialect.has_table(conn, self.name, schema=self.schema) @@ -374,15 +363,15 @@ class Table(SchemaItem, sql.TableClause): def create(self, bind=None, checkfirst=False): """Issue a ``CREATE`` statement for this table. - See also ``metadata.create_all()``.""" - + See also ``metadata.create_all()``. + """ self.metadata.create_all(bind=bind, checkfirst=checkfirst, tables=[self]) def drop(self, bind=None, checkfirst=False): """Issue a ``DROP`` statement for this table. - See also ``metadata.drop_all()``.""" - + See also ``metadata.drop_all()``. + """ self.metadata.drop_all(bind=bind, checkfirst=checkfirst, tables=[self]) def tometadata(self, metadata, schema=None): @@ -401,159 +390,168 @@ class Table(SchemaItem, sql.TableClause): args.append(c.copy()) return Table(self.name, metadata, schema=schema, *args) -class Column(SchemaItem, sql._ColumnClause): +class Column(SchemaItem, expression._ColumnClause): """Represent a column in a database table. - This is a subclass of ``sql.ColumnClause`` and represents an - actual existing table in the database, in a similar fashion as + This is a subclass of ``expression.ColumnClause`` and represents an actual + existing table in the database, in a similar fashion as ``TableClause``/``Table``. """ - def __init__(self, name, type_, *args, **kwargs): + def __init__(self, *args, **kwargs): """Construct a new ``Column`` object. Arguments are: name - The name of this column. This should be the identical name - as it appears, or will appear, in the database. + The name of this column. This should be the identical name as it + appears, or will appear, in the database. Name may be omitted at + construction time but must be assigned before adding a Column + instance to a Table. type\_ - The ``TypeEngine`` for this column. This can be any - subclass of ``types.AbstractType``, including the - database-agnostic types defined in the types module, - database-specific types defined within specific database - modules, or user-defined types. 
If the column contains a - ForeignKey, the type can also be None, in which case the + The ``TypeEngine`` for this column. This can be any subclass of + ``types.AbstractType``, including the database-agnostic types + defined in the types module, database-specific types defined within + specific database modules, or user-defined types. If the column + contains a ForeignKey, the type can also be None, in which case the type assigned will be that of the referenced column. \*args - Constraint, ForeignKey, ColumnDefault and Sequence objects - should be added as list values. + Constraint, ForeignKey, ColumnDefault and Sequence objects should be + added as list values. \**kwargs Keyword arguments include: key - Defaults to None: an optional *alias name* for this column. - The column will then be identified everywhere in an - application, including the column list on its Table, by - this key, and not the given name. Generated SQL, however, - will still reference the column by its actual name. + Defaults to the column name: a Python-only *alias name* for this + column. + + The column will then be identified everywhere in an application, + including the column list on its Table, by this key, and not the + given name. Generated SQL, however, will still reference the + column by its actual name. primary_key - Defaults to False: True if this column is a primary key - column. Multiple columns can have this flag set to - specify composite primary keys. As an alternative, the - primary key of a Table can be specified via an explicit - ``PrimaryKeyConstraint`` instance appended to the Table's - list of objects. + Defaults to False: True if this column is a primary key column. + Multiple columns can have this flag set to specify composite + primary keys. As an alternative, the primary key of a Table can + be specified via an explicit ``PrimaryKeyConstraint`` instance + appended to the Table's list of objects. nullable - Defaults to True : True if this column should allow - nulls. True is the default unless this column is a primary - key column. + Defaults to True : True if this column should allow nulls. True is + the default unless this column is a primary key column. default Defaults to None: a scalar, Python callable, or ``ClauseElement`` - representing the *default value* for this column, which will - be invoked upon insert if this column is not present in - the insert list or is given a value of None. The default - expression will be converted into a ``ColumnDefault`` object - upon initialization. + representing the *default value* for this column, which will be + invoked upon insert if this column is not present in the insert + list or is given a value of None. The default expression will be + converted into a ``ColumnDefault`` object upon initialization. _is_oid - Defaults to False: used internally to indicate that this - column is used as the quasi-hidden "oid" column + Defaults to False: used internally to indicate that this column is + used as the quasi-hidden "oid" column index - Defaults to False: indicates that this column is - indexed. The name of the index is autogenerated. to - specify indexes with explicit names or indexes that - contain multiple columns, use the ``Index`` construct instead. + Defaults to False: indicates that this column is indexed. The name + of the index is autogenerated. to specify indexes with explicit + names or indexes that contain multiple columns, use the ``Index`` + construct instead. 
+ + info + Defaults to {}: A space to store application specific data; this + must be a dictionary. unique - Defaults to False: indicates that this column contains a - unique constraint, or if `index` is True as well, - indicates that the Index should be created with the unique - flag. To specify multiple columns in the constraint/index - or to specify an explicit name, use the - ``UniqueConstraint`` or ``Index`` constructs instead. + Defaults to False: indicates that this column contains a unique + constraint, or if `index` is True as well, indicates that the + Index should be created with the unique flag. To specify multiple + columns in the constraint/index or to specify an explicit name, + use the ``UniqueConstraint`` or ``Index`` constructs instead. autoincrement - Defaults to True: indicates that integer-based primary key - columns should have autoincrementing behavior, if - supported by the underlying database. This will affect - ``CREATE TABLE`` statements such that they will use the - databases *auto-incrementing* keyword (such as ``SERIAL`` - for Postgres, ``AUTO_INCREMENT`` for Mysql) and will also - affect the behavior of some dialects during ``INSERT`` - statement execution such that they will assume primary key - values are created in this manner. If a ``Column`` has an + Defaults to True: indicates that integer-based primary key columns + should have autoincrementing behavior, if supported by the + underlying database. This will affect ``CREATE TABLE`` statements + such that they will use the databases *auto-incrementing* keyword + (such as ``SERIAL`` for Postgres, ``AUTO_INCREMENT`` for Mysql) + and will also affect the behavior of some dialects during + ``INSERT`` statement execution such that they will assume primary + key values are created in this manner. If a ``Column`` has an explicit ``ColumnDefault`` object (such as via the `default` - keyword, or a ``Sequence`` or ``PassiveDefault``), then - the value of `autoincrement` is ignored and is assumed to be - False. `autoincrement` value is only significant for a - column with a type or subtype of Integer. + keyword, or a ``Sequence`` or ``PassiveDefault``), then the value + of `autoincrement` is ignored and is assumed to be False. + `autoincrement` value is only significant for a column with a type + or subtype of Integer. quote - Defaults to False: indicates that the Column identifier - must be properly escaped and quoted before being sent to - the database. This flag should normally not be required - as dialects can auto-detect conditions where quoting is - required. - - case_sensitive - Defaults to True: indicates quoting should be used if the - identifier contains mixed case. + When True, indicates that the Column identifier must be quoted. + This flag does *not* disable quoting; for case-insensitive names, + use an all lower case identifier. 
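Tying the keyword arguments above together, a minimal sketch (table name, column names and the info payload are made up)::

    from sqlalchemy import MetaData, Table, Column, Integer, String, DateTime
    import datetime

    metadata = MetaData()
    accounts = Table('accounts', metadata,
        Column('id', Integer, primary_key=True),          # autoincrement=True by default
        Column('email', String(100), nullable=False, unique=True),
        Column('login', String(30), key='username', index=True),
        Column('created', DateTime, default=datetime.datetime.now),
        Column('notes', String(200), info={'owner': 'billing'}))

    # 'key' is the Python-side alias; generated SQL still says "login"
    print accounts.c.username.name                        # prints 'login'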
""" + name = kwargs.pop('name', None) + type_ = kwargs.pop('type_', None) + if args: + args = list(args) + if isinstance(args[0], basestring): + if name is not None: + raise exceptions.ArgumentError( + "May not pass name positionally and as a keyword.") + name = args.pop(0) + if args: + if (isinstance(args[0], types.AbstractType) or + (isinstance(args[0], type) and + issubclass(args[0], types.AbstractType))): + if type_ is not None: + raise exceptions.ArgumentError( + "May not pass type_ positionally and as a keyword.") + type_ = args.pop(0) + super(Column, self).__init__(name, None, type_) self.args = args self.key = kwargs.pop('key', name) - self._primary_key = kwargs.pop('primary_key', False) + self.primary_key = kwargs.pop('primary_key', False) self.nullable = kwargs.pop('nullable', not self.primary_key) self._is_oid = kwargs.pop('_is_oid', False) self.default = kwargs.pop('default', None) self.index = kwargs.pop('index', None) self.unique = kwargs.pop('unique', None) self.quote = kwargs.pop('quote', False) - self._set_casing_strategy(kwargs) self.onupdate = kwargs.pop('onupdate', None) self.autoincrement = kwargs.pop('autoincrement', True) self.constraints = util.Set() - self.__originating_column = self - self._foreign_keys = util.OrderedSet() - if len(kwargs): - raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) - - primary_key = util.SimpleProperty('_primary_key') - foreign_keys = util.SimpleProperty('_foreign_keys') - columns = property(lambda self:[self]) + self.foreign_keys = util.OrderedSet() + if kwargs.get('info'): + self._info = kwargs.pop('info') + if kwargs: + raise exceptions.ArgumentError( + "Unknown arguments passed to Column: " + repr(kwargs.keys())) def __str__(self): if self.table is not None: - if self.table.named_with_column(): - return (self.table.encodedname + "." + self.encodedname) + if self.table.named_with_column: + return (self.table.description + "." 
+ self.description) else: - return self.encodedname + return self.description else: - return self.encodedname - - def _derived_metadata(self): - return self.table.metadata + return self.description - def _get_engine(self): + def bind(self): return self.table.bind + bind = property(bind) def references(self, column): - """return true if this column references the given column via foreign key""" + """Return True if this references the given column via a foreign key.""" for fk in self.foreign_keys: - if fk.column is column: + if fk.references(column.table): return True else: return False - + def append_foreign_key(self, fk): fk._set_parent(self) @@ -561,7 +559,7 @@ class Column(SchemaItem, sql._ColumnClause): kwarg = [] if self.key != self.name: kwarg.append('key') - if self._primary_key: + if self.primary_key: kwarg.append('primary_key') if not self.nullable: kwarg.append('nullable') @@ -569,36 +567,56 @@ class Column(SchemaItem, sql._ColumnClause): kwarg.append('onupdate') if self.default: kwarg.append('default') - return "Column(%s)" % string.join( + return "Column(%s)" % ', '.join( [repr(self.name)] + [repr(self.type)] + [repr(x) for x in self.foreign_keys if x is not None] + [repr(x) for x in self.constraints] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg] - , ',') - - def _get_parent(self): - return self.table + [(self.table and "table=<%s>" % self.table.description or "")] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]) def _set_parent(self, table): + if self.name is None: + raise exceptions.ArgumentError( + "Column must be constructed with a name or assign .name " + "before adding to a Table.") + if self.key is None: + self.key = self.name + self.metadata = table.metadata if getattr(self, 'table', None) is not None: raise exceptions.ArgumentError("this Column already has a table!") if not self._is_oid: - table._columns.add(self) + self._pre_existing_column = table._columns.get(self.key) + + table._columns.replace(self) + else: + self._pre_existing_column = None + if self.primary_key: - table.primary_key.add(self) + table.primary_key.replace(self) elif self.key in table.primary_key: - raise exceptions.ArgumentError("Trying to redefine primary-key column '%s' as a non-primary-key column on table '%s'" % (self.key, table.fullname)) + raise exceptions.ArgumentError( + "Trying to redefine primary-key column '%s' as a " + "non-primary-key column on table '%s'" % ( + self.key, table.fullname)) # if we think this should not raise an error, we'd instead do this: #table.primary_key.remove(self) self.table = table if self.index: if isinstance(self.index, basestring): - raise exceptions.ArgumentError("The 'index' keyword argument on Column is boolean only. To create indexes with a specific name, create an explicit Index object external to the Table.") + raise exceptions.ArgumentError( + "The 'index' keyword argument on Column is boolean only. " + "To create indexes with a specific name, create an " + "explicit Index object external to the Table.") Index('ix_%s' % self._label, self, unique=self.unique) elif self.unique: if isinstance(self.unique, basestring): - raise exceptions.ArgumentError("The 'unique' keyword argument on Column is boolean only. To create unique constraints or indexes with a specific name, append an explicit UniqueConstraint to the Table's list of elements, or create an explicit Index object external to the Table.") + raise exceptions.ArgumentError( + "The 'unique' keyword argument on Column is boolean only. 
" + "To create unique constraints or indexes with a specific " + "name, append an explicit UniqueConstraint to the Table's " + "list of elements, or create an explicit Index object " + "external to the Table.") table.append_constraint(UniqueConstraint(self.key)) toinit = list(self.args) @@ -615,7 +633,7 @@ class Column(SchemaItem, sql._ColumnClause): This is used in ``Table.tometadata``. """ - return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, case_sensitive=self._case_sensitive_setting, quote=self.quote, index=self.index, *[c.copy() for c in self.constraints]) + return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, index=self.index, autoincrement=self.autoincrement, *[c.copy() for c in self.constraints]) def _make_proxy(self, selectable, name = None): """Create a *proxy* for this column. @@ -627,9 +645,8 @@ class Column(SchemaItem, sql._ColumnClause): fk = [ForeignKey(f._colspec) for f in self.foreign_keys] c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk) c.table = selectable - c.orig_set = self.orig_set - c.__originating_column = self.__originating_column - c._distance = self._distance + 1 + c.proxies = [self] + c._pre_existing_column = self._pre_existing_column if not c._is_oid: selectable.columns.add(c) if self.primary_key: @@ -638,43 +655,62 @@ class Column(SchemaItem, sql._ColumnClause): return c - def _case_sens(self): - """Redirect the `case_sensitive` accessor to use the ultimate - parent column which created this one.""" - - return self.__originating_column._get_case_sensitive() - case_sensitive = property(_case_sens, lambda s,v:None) - def get_children(self, schema_visitor=False, **kwargs): if schema_visitor: return [x for x in (self.default, self.onupdate) if x is not None] + \ list(self.foreign_keys) + list(self.constraints) else: - return sql._ColumnClause.get_children(self, **kwargs) + return expression._ColumnClause.get_children(self, **kwargs) class ForeignKey(SchemaItem): - """Defines a column-level ``ForeignKey`` constraint between two columns. + """Defines a column-level FOREIGN KEY constraint between two columns. - ``ForeignKey`` is specified as an argument to a Column object. + ``ForeignKey`` is specified as an argument to a ``Column`` object. - One or more ``ForeignKey`` objects are used within a - ``ForeignKeyConstraint`` object which represents the table-level - constraint definition. + For a composite (multiple column) FOREIGN KEY, use a ForeignKeyConstraint + within the Table definition. """ - def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None): - """Construct a new ``ForeignKey`` object. + def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None): + """Construct a column-level FOREIGN KEY. column - Can be a ``schema.Column`` object representing the relationship, - or just its string name given as ``tablename.columnname``. - schema can be specified as ``schema.tablename.columnname``. + A single target column for the key relationship. A ``Column`` + object or a column name as a string: ``tablename.columnname`` or + ``schema.tablename.columnname``. 
constraint - Is the owning ``ForeignKeyConstraint`` object, if any. if not - given, then a ``ForeignKeyConstraint`` will be automatically - created and added to the parent table. + Optional. A parent ``ForeignKeyConstraint`` object. If not + supplied, a ``ForeignKeyConstraint`` will be automatically created + and added to the parent table. + + name + Optional string. An in-database name for the key if `constraint` is + not provided. + + onupdate + Optional string. If set, emit ON UPDATE when issuing DDL + for this constraint. Typical values include CASCADE, DELETE and + RESTRICT. + + ondelete + Optional string. If set, emit ON DELETE when issuing DDL + for this constraint. Typical values include CASCADE, DELETE and + RESTRICT. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + use_alter + If True, do not emit this key as part of the CREATE TABLE + definition. Instead, use ALTER TABLE after table creation to add + the key. Useful for circular dependencies. """ self._colspec = column @@ -684,6 +720,8 @@ class ForeignKey(SchemaItem): self.name = name self.onupdate = onupdate self.ondelete = ondelete + self.deferrable = deferrable + self.initially = initially def __repr__(self): return "ForeignKey(%s)" % repr(self._get_colspec()) @@ -697,67 +735,100 @@ class ForeignKey(SchemaItem): if isinstance(self._colspec, basestring): return self._colspec elif self._colspec.table.schema is not None: - return "%s.%s.%s" % (self._colspec.table.schema, self._colspec.table.name, self._colspec.key) + return "%s.%s.%s" % (self._colspec.table.schema, + self._colspec.table.name, self._colspec.key) else: return "%s.%s" % (self._colspec.table.name, self._colspec.key) def references(self, table): - """Return True if the given table is referenced by this ``ForeignKey``.""" + """Return True if the given table is referenced by this ForeignKey.""" - return table.corresponding_column(self.column, False) is not None - - def _init_column(self): - # ForeignKey inits its remote column as late as possible, so tables can - # be defined without dependencies + return table.corresponding_column(self.column) is not None + + def get_referent(self, table): + """Return the column in the given table referenced by this ForeignKey. + + Returns None if this ``ForeignKey`` does not reference the given table. + """ + return table.corresponding_column(self.column) + + def column(self): + # ForeignKey inits its remote column as late as possible, so tables + # can be defined without dependencies if self._column is None: if isinstance(self._colspec, basestring): - # locate the parent table this foreign key is attached to. - # we use the "original" column which our parent column represents - # (its a list of columns/other ColumnElements if the parent table is a UNION) - for c in self.parent.orig_set: + # locate the parent table this foreign key is attached to. 
we + # use the "original" column which our parent column represents + # (its a list of columns/other ColumnElements if the parent + # table is a UNION) + for c in self.parent.base_columns: if isinstance(c, Column): parenttable = c.table break else: - raise exceptions.ArgumentError("Parent column '%s' does not descend from a table-attached Column" % str(self.parent)) - m = re.match(r"^(.+?)(?:\.(.+?))?(?:\.(.+?))?$", self._colspec, re.UNICODE) + raise exceptions.ArgumentError( + "Parent column '%s' does not descend from a " + "table-attached Column" % str(self.parent)) + m = re.match(r"^(.+?)(?:\.(.+?))?(?:\.(.+?))?$", self._colspec, + re.UNICODE) if m is None: - raise exceptions.ArgumentError("Invalid foreign key column specification: " + self._colspec) + raise exceptions.ArgumentError( + "Invalid foreign key column specification: %s" % + self._colspec) if m.group(3) is None: (tname, colname) = m.group(1, 2) schema = None else: (schema,tname,colname) = m.group(1,2,3) - table = Table(tname, parenttable.metadata, mustexist=True, schema=schema) + if _get_table_key(tname, schema) not in parenttable.metadata: + raise exceptions.NoReferencedTableError( + "Could not find table '%s' with which to generate a " + "foreign key" % tname) + table = Table(tname, parenttable.metadata, + mustexist=True, schema=schema) try: if colname is None: - # colname is None in the case that ForeignKey argument was specified - # as table name only, in which case we match the column name to the same - # column on the parent. + # colname is None in the case that ForeignKey argument + # was specified as table name only, in which case we + # match the column name to the same column on the + # parent. key = self.parent self._column = table.c[self.parent.key] else: self._column = table.c[colname] except KeyError, e: - raise exceptions.ArgumentError("Could not create ForeignKey '%s' on table '%s': table '%s' has no column named '%s'" % (self._colspec, parenttable.name, table.name, str(e))) + raise exceptions.ArgumentError( + "Could not create ForeignKey '%s' on table '%s': " + "table '%s' has no column named '%s'" % ( + self._colspec, parenttable.name, table.name, str(e))) + + elif isinstance(self._colspec, expression.Operators): + self._column = self._colspec.clause_element() else: self._column = self._colspec - - # propigate TypeEngine to parent if it didnt have one + + # propagate TypeEngine to parent if it didn't have one if isinstance(self.parent.type, types.NullType): self.parent.type = self._column.type return self._column - column = property(lambda s: s._init_column()) - - def _get_parent(self): - return self.parent + column = property(column) def _set_parent(self, column): self.parent = column + if self.parent._pre_existing_column is not None: + # remove existing FK which matches us + for fk in self.parent._pre_existing_column.foreign_keys: + if fk._colspec == self._colspec: + self.parent.table.foreign_keys.remove(fk) + self.parent.table.constraints.remove(fk.constraint) + if self.constraint is None and isinstance(self.parent.table, Table): - self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete) + self.constraint = ForeignKeyConstraint( + [], [], use_alter=self.use_alter, name=self.name, + onupdate=self.onupdate, ondelete=self.ondelete, + deferrable=self.deferrable, initially=self.initially) self.parent.table.append_constraint(self.constraint) self.constraint._append_fk(self) @@ -769,20 +840,11 @@ class DefaultGenerator(SchemaItem): 
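A hedged sketch of the ``ForeignKey`` arguments above, using the string form of the column spec (table and column names are invented)::

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey

    metadata = MetaData()
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True))

    addresses = Table('addresses', metadata,
        Column('id', Integer, primary_key=True),
        Column('user_id', Integer,
               ForeignKey('users.id', ondelete='CASCADE',
                          deferrable=True, initially='DEFERRED')),
        Column('email', String(100)))

    # the target is resolved lazily from the "tablename.columnname" string
    fk = list(addresses.c.user_id.foreign_keys)[0]
    assert fk.column is users.c.id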
def __init__(self, for_update=False, metadata=None): self.for_update = for_update - self._metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata') - - def _derived_metadata(self): - try: - return self.column.table.metadata - except AttributeError: - return self._metadata - - def _get_parent(self): - return getattr(self, 'column', None) + self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata') def _set_parent(self, column): self.column = column - self._metadata = self.column.table.metadata + self.metadata = self.column.table.metadata if self.for_update: self.column.onupdate = self else: @@ -790,7 +852,7 @@ class DefaultGenerator(SchemaItem): def execute(self, bind=None, **kwargs): if bind is None: - bind = self._get_engine(raiseerr=True) + bind = _bind_or_error(self) return bind._execute_default(self, **kwargs) def __repr__(self): @@ -809,25 +871,46 @@ class PassiveDefault(DefaultGenerator): class ColumnDefault(DefaultGenerator): """A plain default value on a column. - This could correspond to a constant, a callable function, or a SQL - clause. + This could correspond to a constant, a callable function, or a SQL clause. """ def __init__(self, arg, **kwargs): super(ColumnDefault, self).__init__(**kwargs) if callable(arg): - if not inspect.isfunction(arg): - self.arg = lambda ctx: arg() - else: - argspec = inspect.getargspec(arg) - if len(argspec[0]) == 0: - self.arg = lambda ctx: arg() - elif len(argspec[0]) != 1: - raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments") - else: - self.arg = arg + arg = self._maybe_wrap_callable(arg) + self.arg = arg + + def _maybe_wrap_callable(self, fn): + """Backward compat: Wrap callables that don't accept a context.""" + + if inspect.isfunction(fn): + inspectable = fn + elif inspect.isclass(fn): + inspectable = fn.__init__ + elif hasattr(fn, '__call__'): + inspectable = fn.__call__ else: - self.arg = arg + # probably not inspectable, try anyways. 
+ inspectable = fn + try: + argspec = inspect.getargspec(inspectable) + except TypeError: + return lambda ctx: fn() + + positionals = len(argspec[0]) + if inspect.ismethod(inspectable): + positionals -= 1 + + if positionals == 0: + return lambda ctx: fn() + + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + if positionals - defaulted > 1: + raise exceptions.ArgumentError( + "ColumnDefault Python function takes zero or one " + "positional arguments") + return fn + def _visit_name(self): if self.for_update: @@ -840,48 +923,70 @@ class ColumnDefault(DefaultGenerator): return "ColumnDefault(%s)" % repr(self.arg) class Sequence(DefaultGenerator): - """Represent a sequence, which applies to Oracle and Postgres databases.""" + """Represents a named database sequence.""" - def __init__(self, name, start = None, increment = None, optional=False, quote=False, **kwargs): + def __init__(self, name, start=None, increment=None, schema=None, + optional=False, quote=False, **kwargs): super(Sequence, self).__init__(**kwargs) self.name = name self.start = start self.increment = increment self.optional=optional self.quote = quote - self._set_casing_strategy(kwargs) + self.schema = schema + self.kwargs = kwargs def __repr__(self): - return "Sequence(%s)" % string.join( + return "Sequence(%s)" % ', '.join( [repr(self.name)] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['start', 'increment', 'optional']] - , ',') + ["%s=%s" % (k, repr(getattr(self, k))) + for k in ['start', 'increment', 'optional']]) def _set_parent(self, column): super(Sequence, self)._set_parent(column) column.sequence = self def create(self, bind=None, checkfirst=True): + """Creates this sequence in the database.""" + if bind is None: - bind = self._get_engine(raiseerr=True) + bind = _bind_or_error(self) bind.create(self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=True): + """Drops this sequence from the database.""" + if bind is None: - bind = self._get_engine(raiseerr=True) + bind = _bind_or_error(self) bind.drop(self, checkfirst=checkfirst) class Constraint(SchemaItem): - """Represent a table-level ``Constraint`` such as a composite primary key, foreign key, or unique constraint. + """A table-level SQL constraint, such as a KEY. - Implements a hybrid of dict/setlike behavior with regards to the - list of underying columns. + Implements a hybrid of dict/setlike behavior with regards to the list of + underying columns. """ - def __init__(self, name=None): + def __init__(self, name=None, deferrable=None, initially=None): + """Create a SQL constraint. + + name + Optional, the in-database name of this ``Constraint``. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + """ + self.name = name - self.columns = sql.ColumnCollection() + self.columns = expression.ColumnCollection() + self.deferrable = deferrable + self.initially = initially def __contains__(self, x): return self.columns.contains_column(x) @@ -901,20 +1006,43 @@ class Constraint(SchemaItem): def copy(self): raise NotImplementedError() - def _get_parent(self): - return getattr(self, 'table', None) - class CheckConstraint(Constraint): - def __init__(self, sqltext, name=None): - super(CheckConstraint, self).__init__(name) + """A table- or column-level CHECK constraint. + + Can be included in the definition of a Table or Column. 
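As a sketch of the two CHECK placements just described (table, column names and expressions are illustrative)::

    from sqlalchemy import MetaData, Table, Column, Integer, Numeric, CheckConstraint

    metadata = MetaData()
    items = Table('items', metadata,
        Column('id', Integer, primary_key=True),
        # column-level CHECK, attached to a single column
        Column('price', Numeric(10, 2), CheckConstraint('price > 0')),
        Column('sale_price', Numeric(10, 2)),
        # table-level CHECK, free to reference several columns
        CheckConstraint('sale_price <= price', name='ck_items_sale_price'))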
+ """ + + def __init__(self, sqltext, name=None, deferrable=None, initially=None): + """Construct a CHECK constraint. + + sqltext + A string containing the constraint definition. Will be used + verbatim. + + name + Optional, the in-database name of the constraint. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + """ + + super(CheckConstraint, self).__init__(name, deferrable, initially) + if not isinstance(sqltext, basestring): + raise exc.ArgumentError( + "sqltext must be a string and will be used verbatim.") self.sqltext = sqltext - def _visit_name(self): + def __visit_name__(self): if isinstance(self.parent, Table): return "check_constraint" else: return "column_check_constraint" - __visit_name__ = property(_visit_name) + __visit_name__ = property(__visit_name__) def _set_parent(self, parent): self.parent = parent @@ -924,10 +1052,53 @@ class CheckConstraint(Constraint): return CheckConstraint(self.sqltext, name=self.name) class ForeignKeyConstraint(Constraint): - """Table-level foreign key constraint, represents a collection of ``ForeignKey`` objects.""" + """A table-level FOREIGN KEY constraint. + + Defines a single column or composite FOREIGN KEY ... REFERENCES + constraint. For a no-frills, single column foreign key, adding a + ``ForeignKey`` to the definition of a ``Column`` is a shorthand equivalent + for an unnamed, single column ``ForeignKeyConstraint``. + """ + + def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False, deferrable=None, initially=None): + """Construct a composite-capable FOREIGN KEY. + + columns + A sequence of local column names. The named columns must be defined + and present in the parent Table. + + refcolumns + A sequence of foreign column names or Column objects. The columns + must all be located within the same Table. + + name + Optional, the in-database name of the key. + + onupdate + Optional string. If set, emit ON UPDATE when issuing DDL + for this constraint. Typical values include CASCADE, DELETE and + RESTRICT. + + ondelete + Optional string. If set, emit ON DELETE when issuing DDL + for this constraint. Typical values include CASCADE, DELETE and + RESTRICT. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + use_alter + If True, do not emit this key as part of the CREATE TABLE + definition. Instead, use ALTER TABLE after table creation to add + the key. Useful for circular dependencies. 
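A minimal sketch of the composite form described above (tables and column names are invented); for mutually dependent tables, adding ``use_alter=True`` and a ``name`` moves the constraint out of CREATE TABLE and into a later ALTER TABLE::

    from sqlalchemy import (MetaData, Table, Column, Integer, String,
                            ForeignKeyConstraint)

    metadata = MetaData()
    invoices = Table('invoices', metadata,
        Column('invoice_id', Integer, primary_key=True),
        Column('ref_num', Integer, primary_key=True))

    invoice_items = Table('invoice_items', metadata,
        Column('item_id', Integer, primary_key=True),
        Column('item_name', String(60)),
        Column('invoice_id', Integer),
        Column('ref_num', Integer),
        # local column names first, then the referenced "table.column" specs
        ForeignKeyConstraint(['invoice_id', 'ref_num'],
                             ['invoices.invoice_id', 'invoices.ref_num'],
                             onupdate='CASCADE', ondelete='CASCADE'))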
+ """ - def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False): - super(ForeignKeyConstraint, self).__init__(name) + super(ForeignKeyConstraint, self).__init__(name, deferrable, initially) self.__colnames = columns self.__refcolnames = refcolumns self.elements = util.OrderedSet() @@ -939,9 +1110,10 @@ class ForeignKeyConstraint(Constraint): def _set_parent(self, table): self.table = table - table.constraints.add(self) - for (c, r) in zip(self.__colnames, self.__refcolnames): - self.append_element(c,r) + if self not in table.constraints: + table.constraints.add(self) + for (c, r) in zip(self.__colnames, self.__refcolnames): + self.append_element(c,r) def append_element(self, col, refcol): fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter) @@ -956,27 +1128,62 @@ class ForeignKeyConstraint(Constraint): return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter) class PrimaryKeyConstraint(Constraint): + """A table-level PRIMARY KEY constraint. + + Defines a single column or composite PRIMARY KEY constraint. For a + no-frills primary key, adding ``primary_key=True`` to one or more + ``Column`` definitions is a shorthand equivalent for an unnamed single- or + multiple-column PrimaryKeyConstraint. + """ + def __init__(self, *columns, **kwargs): - super(PrimaryKeyConstraint, self).__init__(name=kwargs.pop('name', None)) + """Construct a composite-capable PRIMARY KEY. + + \*columns + A sequence of column names. All columns named must be defined and + present within the parent Table. + + name + Optional, the in-database name of the key. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + """ + + constraint_args = dict(name=kwargs.pop('name', None), + deferrable=kwargs.pop('deferrable', None), + initially=kwargs.pop('initially', None)) + if kwargs: + raise exceptions.ArgumentError( + 'Unknown PrimaryKeyConstraint argument(s): %s' % + ', '.join([repr(x) for x in kwargs.keys()])) + + super(PrimaryKeyConstraint, self).__init__(**constraint_args) self.__colnames = list(columns) def _set_parent(self, table): self.table = table table.primary_key = self - for c in self.__colnames: - self.append_column(table.c[c]) + for name in self.__colnames: + self.add(table.c[name]) def add(self, col): - self.append_column(col) + self.columns.add(col) + col.primary_key=True + append_column = add + + def replace(self, col): + self.columns.replace(col) def remove(self, col): col.primary_key=False del self.columns[col.key] - def append_column(self, col): - self.columns.add(col) - col.primary_key=True - def copy(self): return PrimaryKeyConstraint(name=self.name, *[c.key for c in self]) @@ -984,8 +1191,42 @@ class PrimaryKeyConstraint(Constraint): return self.columns == other class UniqueConstraint(Constraint): + """A table-level UNIQUE constraint. + + Defines a single column or composite UNIQUE constraint. For a no-frills, + single column constraint, adding ``unique=True`` to the ``Column`` + definition is a shorthand equivalent for an unnamed, single column + UniqueConstraint. 
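The explicit forms of both table-level constraints look roughly like this hedged sketch (names are invented)::

    from sqlalchemy import (MetaData, Table, Column, Integer, String,
                            PrimaryKeyConstraint, UniqueConstraint)

    metadata = MetaData()
    pages = Table('pages', metadata,
        Column('doc_id', Integer),
        Column('version', Integer),
        Column('slug', String(40)),
        Column('locale', String(10)),
        # composite primary key, spelled out instead of primary_key=True flags
        PrimaryKeyConstraint('doc_id', 'version', name='pk_pages'),
        # composite UNIQUE constraint across two columns
        UniqueConstraint('slug', 'locale', name='uq_pages_slug_locale'))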
+ """ + def __init__(self, *columns, **kwargs): - super(UniqueConstraint, self).__init__(name=kwargs.pop('name', None)) + """Construct a UNIQUE constraint. + + \*columns + A sequence of column names. All columns named must be defined and + present within the parent Table. + + name + Optional, the in-database name of the key. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + """ + + constraint_args = dict(name=kwargs.pop('name', None), + deferrable=kwargs.pop('deferrable', None), + initially=kwargs.pop('initially', None)) + if kwargs: + raise exceptions.ArgumentError( + 'Unknown UniqueConstraint argument(s): %s' % + ', '.join([repr(x) for x in kwargs.keys()])) + + super(UniqueConstraint, self).__init__(**constraint_args) self.__colnames = list(columns) def _set_parent(self, table): @@ -1001,7 +1242,12 @@ class UniqueConstraint(Constraint): return UniqueConstraint(name=self.name, *self.__colnames) class Index(SchemaItem): - """Represent an index of columns from a database table.""" + """A table-level INDEX. + + Defines a composite (one or more column) INDEX. For a no-frills, single + column index, adding ``index=True`` to the ``Column`` definition is + a shorthand equivalent for an unnamed, single column Index. + """ def __init__(self, name, *columns, **kwargs): """Construct an index object. @@ -1012,35 +1258,34 @@ class Index(SchemaItem): The name of the index \*columns - Columns to include in the index. All columns must belong to - the same table, and no column may appear more than once. + Columns to include in the index. All columns must belong to the same + table, and no column may appear more than once. \**kwargs Keyword arguments include: unique - Defaults to True: create a unique index. - + Defaults to False: create a unique index. + + postgres_where + Defaults to None: create a partial index when using PostgreSQL """ self.name = name self.columns = [] self.table = None self.unique = kwargs.pop('unique', False) - self._init_items(*columns) + self.kwargs = kwargs - def _derived_metadata(self): - return self.table.metadata + self._init_items(*columns) def _init_items(self, *args): for column in args: self.append_column(column) - def _get_parent(self): - return self.table - def _set_parent(self, table): self.table = table + self.metadata = table.metadata table.indexes.add(self) def append_column(self, column): @@ -1050,28 +1295,25 @@ class Index(SchemaItem): self._set_parent(column.table) elif column.table != self.table: # all columns muse be from same table - raise exceptions.ArgumentError("All index columns must be from same table. " - "%s is from %s not %s" % (column, - column.table, - self.table)) + raise exceptions.ArgumentError( + "All index columns must be from same table. 
" + "%s is from %s not %s" % (column, column.table, self.table)) elif column.name in [ c.name for c in self.columns ]: - raise exceptions.ArgumentError("A column may not appear twice in the " - "same index (%s already has column %s)" - % (self.name, column)) + raise exceptions.ArgumentError( + "A column may not appear twice in the " + "same index (%s already has column %s)" % (self.name, column)) self.columns.append(column) def create(self, bind=None): - if bind is not None: - bind.create(self) - else: - self._get_engine(raiseerr=True).create(self) + if bind is None: + bind = _bind_or_error(self) + bind.create(self) return self def drop(self, bind=None): - if bind is not None: - bind.drop(self) - else: - self._get_engine(raiseerr=True).drop(self) + if bind is None: + bind = _bind_or_error(self) + bind.drop(self) def __str__(self): return repr(self) @@ -1083,182 +1325,635 @@ class Index(SchemaItem): (self.unique and ', unique=True') or '') class MetaData(SchemaItem): - """Represent a collection of Tables and their associated schema constructs.""" + """A collection of Tables and their associated schema constructs. + + Holds a collection of Tables and an optional binding to an ``Engine`` or + ``Connection``. If bound, the [sqlalchemy.schema#Table] objects in the + collection and their columns may participate in implicit SQL execution. + + The ``bind`` property may be assigned to dynamically. A common pattern is + to start unbound and then bind later when an engine is available:: + + metadata = MetaData() + # define tables + Table('mytable', metadata, ...) + # connect to an engine later, perhaps after loading a URL from a + # configuration file + metadata.bind = an_engine + + MetaData is a thread-safe object after tables have been explicitly defined + or loaded via reflection. + """ __visit_name__ = 'metadata' - - def __init__(self, bind=None, **kwargs): - """create a new MetaData object. - - bind - an Engine, or a string or URL instance which will be passed - to create_engine(), this MetaData will be bound to the resulting - engine. - - case_sensitive - popped from \**kwargs, indicates default case sensitive setting for - all contained objects. defaults to True. - - """ + + ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop') + + def __init__(self, bind=None, reflect=False): + """Create a new MetaData object. + + bind + An Engine or Connection to bind to. May also be a string or URL + instance, these are passed to create_engine() and this MetaData will + be bound to the resulting engine. + + reflect + Optional, automatically load all tables from the bound database. + Defaults to False. ``bind`` is required when this option is set. + For finer control over loaded tables, use the ``reflect`` method of + ``MetaData``. 
+ """ self.tables = {} - self._set_casing_strategy(kwargs) self.bind = bind - + self.metadata = self + self.ddl_listeners = util.defaultdict(list) + if reflect: + if not bind: + raise exceptions.ArgumentError( + "A bind must be supplied in conjunction with reflect=True") + self.reflect() + def __repr__(self): return 'MetaData(%r)' % self.bind + def __contains__(self, key): + return key in self.tables + def __getstate__(self): - return {'tables':self.tables, 'casesensitive':self._case_sensitive_setting} + return {'tables': self.tables} def __setstate__(self, state): self.tables = state['tables'] - self._case_sensitive_setting = state['casesensitive'] self._bind = None - + def is_bound(self): - """return True if this MetaData is bound to an Engine.""" + """True if this MetaData is bound to an Engine or Connection.""" + return self._bind is not None + # @deprecated def connect(self, bind, **kwargs): - """bind this MetaData to an Engine. - - DEPRECATED. use metadata.bind = or metadata.bind = . - - bind - a string, URL or Engine instance. If a string or URL, - will be passed to create_engine() along with \**kwargs to - produce the engine which to connect to. otherwise connects - directly to the given Engine. + """Bind this MetaData to an Engine. + + Use ``metadata.bind = `` or ``metadata.bind = ``. + + bind + A string, ``URL``, ``Engine`` or ``Connection`` instance. If a + string or ``URL``, will be passed to ``create_engine()`` along with + ``\**kwargs`` to produce the engine which to connect to. Otherwise + connects directly to the given ``Engine``. + """ + + global URL + if URL is None: + from sqlalchemy.engine.url import URL + if isinstance(bind, (basestring, URL)): + from sqlalchemy import create_engine + self._bind = create_engine(bind, **kwargs) + else: + self._bind = bind + connect = util.deprecated()(connect) + def bind(self): + """An Engine or Connection to which this MetaData is bound. + + This property may be assigned an ``Engine`` or ``Connection``, or + assigned a string or URL to automatically create a basic ``Engine`` + for this bind with ``create_engine()``. """ - - from sqlalchemy.engine.url import URL + return self._bind + + def _bind_to(self, bind): + """Bind this MetaData to an Engine, Connection, string or URL.""" + + global URL + if URL is None: + from sqlalchemy.engine.url import URL + if isinstance(bind, (basestring, URL)): - self._bind = sqlalchemy.create_engine(bind, **kwargs) + from sqlalchemy import create_engine + self._bind = create_engine(bind) else: self._bind = bind + bind = property(bind, _bind_to) - bind = property(lambda self:self._bind, connect, doc="""an Engine or Connection to which this MetaData is bound. this is a settable property as well.""") - def clear(self): self.tables.clear() def remove(self, table): - # TODO: scan all other tables and remove FK _column + # TODO: scan all other tables and remove FK _column del self.tables[table.key] - + def table_iterator(self, reverse=True, tables=None): - import sqlalchemy.sql_util + from sqlalchemy.sql.util import sort_tables if tables is None: tables = self.tables.values() else: tables = util.Set(tables).intersection(self.tables.values()) - sorter = sqlalchemy.sql_util.TableCollection(list(tables)) - return iter(sorter.sort(reverse=reverse)) + return iter(sort_tables(tables, reverse=reverse)) + + def reflect(self, bind=None, schema=None, only=None): + """Load all available table definitions from the database. 
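A small sketch of the binding styles described above, assuming throwaway SQLite URLs::

    from sqlalchemy import MetaData, create_engine

    # bind at construction time and pull in whatever the database already has
    engine = create_engine('sqlite:///existing.db')     # hypothetical file
    reflected = MetaData(bind=engine, reflect=True)
    print reflected.tables.keys()

    # or start unbound and attach a bind later; a URL string is also accepted
    metadata = MetaData()
    # ... Table definitions ...
    metadata.bind = 'sqlite:///other.db'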
+ + Automatically creates ``Table`` entries in this ``MetaData`` for any + table available in the database but not yet present in the + ``MetaData``. May be called multiple times to pick up tables recently + added to the database, however no special action is taken if a table + in this ``MetaData`` no longer exists in the database. + + bind + A ``Connectable`` used to access the database; if None, uses the + existing bind on this ``MetaData``, if any. + + schema + Optional, query and reflect tables from an alterate schema. + + only + Optional. Load only a sub-set of available named tables. May be + specified as a sequence of names or a callable. + + If a sequence of names is provided, only those tables will be + reflected. An error is raised if a table is requested but not + available. Named tables already present in this ``MetaData`` are + ignored. + + If a callable is provided, it will be used as a boolean predicate to + filter the list of potential table names. The callable is called + with a table name and this ``MetaData`` instance as positional + arguments and should return a true value for any table to reflect. + """ + + reflect_opts = {'autoload': True} + if bind is None: + bind = _bind_or_error(self) + conn = None + else: + reflect_opts['autoload_with'] = bind + conn = bind.contextual_connect() + + if schema is not None: + reflect_opts['schema'] = schema + + available = util.OrderedSet(bind.engine.table_names(schema, + connection=conn)) + current = util.Set(self.tables.keys()) + + if only is None: + load = [name for name in available if name not in current] + elif callable(only): + load = [name for name in available + if name not in current and only(name, self)] + else: + missing = [name for name in only if name not in available] + if missing: + s = schema and (" schema '%s'" % schema) or '' + raise exceptions.InvalidRequestError( + 'Could not reflect: requested table(s) not available ' + 'in %s%s: (%s)' % (bind.engine.url, s, ', '.join(missing))) + load = [name for name in only if name not in current] + + for name in load: + Table(name, self, **reflect_opts) + + def append_ddl_listener(self, event, listener): + """Append a DDL event listener to this ``MetaData``. + + The ``listener`` callable will be triggered when this ``MetaData`` is + involved in DDL creates or drops, and will be invoked either before + all Table-related actions or after. + + Arguments are: - def _get_parent(self): - return None + event + One of ``MetaData.ddl_events``; 'before-create', 'after-create', + 'before-drop' or 'after-drop'. + listener + A callable, invoked with three positional arguments: + + event + The event currently being handled + schema_item + The ``MetaData`` object being operated upon + bind + The ``Connection`` bueing used for DDL execution. + + Listeners are added to the MetaData's ``ddl_listeners`` attribute. + + Note: MetaData listeners are invoked even when ``Tables`` are created + in isolation. This may change in a future release. I.e.:: + + # triggers all MetaData and Table listeners: + metadata.create_all() + + # triggers MetaData listeners too: + some.table.create() + """ + + if event not in self.ddl_events: + raise LookupError(event) + self.ddl_listeners[event].append(listener) def create_all(self, bind=None, tables=None, checkfirst=True): """Create all tables stored in this metadata. - This will conditionally create tables depending on if they do - not yet exist in the database. + Conditional by default, will not attempt to recreate tables already + present in the target database. 
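For example, the ``only`` filter might be combined with a conditional create roughly as follows (the database URL and the ``rpt_`` prefix are invented)::

    from sqlalchemy import MetaData, create_engine

    engine = create_engine('sqlite:///reporting.db')
    metadata = MetaData()

    # the callable filter receives (table_name, metadata)
    metadata.reflect(bind=engine, only=lambda name, md: name.startswith('rpt_'))

    # issue CREATE only for tables not already present (checkfirst=True)
    metadata.create_all(bind=engine, checkfirst=True)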
bind - A ``Connectable`` used to access the database; if None, uses - the existing bind on this ``MetaData``, if any. + A ``Connectable`` used to access the database; if None, uses the + existing bind on this ``MetaData``, if any. tables - Optional list of tables, which is a subset of the total + Optional list of ``Table`` objects, which is a subset of the total tables in the ``MetaData`` (others are ignored). + + checkfirst + Defaults to True, don't issue CREATEs for tables already present + in the target database. """ if bind is None: - bind = self._get_engine(raiseerr=True) + bind = _bind_or_error(self) + for listener in self.ddl_listeners['before-create']: + listener('before-create', self, bind) bind.create(self, checkfirst=checkfirst, tables=tables) + for listener in self.ddl_listeners['after-create']: + listener('after-create', self, bind) def drop_all(self, bind=None, tables=None, checkfirst=True): """Drop all tables stored in this metadata. - This will conditionally drop tables depending on if they - currently exist in the database. + Conditional by default, will not attempt to drop tables not present in + the target database. bind A ``Connectable`` used to access the database; if None, uses the existing bind on this ``MetaData``, if any. - + tables - Optional list of tables, which is a subset of the total - tables in the ``MetaData`` (others are ignored). + Optional list of ``Table`` objects, which is a subset of the + total tables in the ``MetaData`` (others are ignored). + + checkfirst + Defaults to True, don't issue CREATEs for tables already present + in the target database. """ if bind is None: - bind = self._get_engine(raiseerr=True) + bind = _bind_or_error(self) + for listener in self.ddl_listeners['before-drop']: + listener('before-drop', self, bind) bind.drop(self, checkfirst=checkfirst, tables=tables) + for listener in self.ddl_listeners['after-drop']: + listener('after-drop', self, bind) - def _derived_metadata(self): - return self +class ThreadLocalMetaData(MetaData): + """A MetaData variant that presents a different ``bind`` in every thread. - def _get_engine(self, raiseerr=False): - if not self.is_bound(): - if raiseerr: - raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") - else: - return None - return self._bind + Makes the ``bind`` property of the MetaData a thread-local value, allowing + this collection of tables to be bound to different ``Engine`` + implementations or connections in each thread. -class ThreadLocalMetaData(MetaData): - """Build upon ``MetaData`` to provide the capability to bind to -multiple ``Engine`` implementations on a dynamically alterable, -thread-local basis. + The ThreadLocalMetaData starts off bound to None in each thread. Binds + must be made explicitly by assigning to the ``bind`` property or using + ``connect()``. You can also re-bind dynamically multiple times per + thread, just like a regular ``MetaData``. + + Use this type of MetaData when your tables are present in more than one + database and you need to address them simultanesouly. 
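A hedged sketch of the per-thread binding just described (URLs and table names are placeholders)::

    import threading
    from sqlalchemy import ThreadLocalMetaData, Table, Column, Integer

    meta = ThreadLocalMetaData()
    accounts = Table('accounts', meta, Column('id', Integer, primary_key=True))

    def worker(url):
        # each thread assigns its own bind; other threads are unaffected
        meta.bind = url
        accounts.create(checkfirst=True)

    for url in ('sqlite:///shard_a.db', 'sqlite:///shard_b.db'):
        threading.Thread(target=worker, args=(url,)).start()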
""" __visit_name__ = 'metadata' - def __init__(self, **kwargs): + def __init__(self): + """Construct a ThreadLocalMetaData.""" + self.context = util.ThreadLocal() self.__engines = {} - super(ThreadLocalMetaData, self).__init__(**kwargs) + super(ThreadLocalMetaData, self).__init__() - def connect(self, engine_or_url, **kwargs): - from sqlalchemy.engine.url import URL - if isinstance(engine_or_url, (basestring, URL)): + # @deprecated + def connect(self, bind, **kwargs): + """Bind to an Engine in the caller's thread. + + Use ``metadata.bind=`` or ``metadata.bind=``. + + bind + A string, ``URL``, ``Engine`` or ``Connection`` instance. If a + string or ``URL``, will be passed to ``create_engine()`` along with + ``\**kwargs`` to produce the engine which to connect to. Otherwise + connects directly to the given ``Engine``. + """ + + global URL + if URL is None: + from sqlalchemy.engine.url import URL + + if isinstance(bind, (basestring, URL)): + try: + engine = self.__engines[bind] + except KeyError: + from sqlalchemy import create_engine + engine = create_engine(bind, **kwargs) + bind = engine + self._bind_to(bind) + connect = util.deprecated()(connect) + + def bind(self): + """The bound Engine or Connection for this thread. + + This property may be assigned an Engine or Connection, or assigned a + string or URL to automatically create a basic Engine for this bind + with ``create_engine()``.""" + + return getattr(self.context, '_engine', None) + + def _bind_to(self, bind): + """Bind to a Connectable in the caller's thread.""" + + global URL + if URL is None: + from sqlalchemy.engine.url import URL + + if isinstance(bind, (basestring, URL)): try: - self.context._engine = self.__engines[engine_or_url] + self.context._engine = self.__engines[bind] except KeyError: - e = sqlalchemy.create_engine(engine_or_url, **kwargs) - self.__engines[engine_or_url] = e + from sqlalchemy import create_engine + e = create_engine(bind) + self.__engines[bind] = e self.context._engine = e else: # TODO: this is squirrely. we shouldnt have to hold onto engines # in a case like this - if not self.__engines.has_key(engine_or_url): - self.__engines[engine_or_url] = engine_or_url - self.context._engine = engine_or_url + if bind not in self.__engines: + self.__engines[bind] = bind + self.context._engine = bind + + bind = property(bind, _bind_to) def is_bound(self): - return hasattr(self.context, '_engine') and self.context._engine is not None + """True if there is a bind for this thread.""" + return (hasattr(self.context, '_engine') and + self.context._engine is not None) def dispose(self): - """Dispose all ``Engines`` to which this ``ThreadLocalMetaData`` has been connected.""" + """Dispose all bound engines, in all thread contexts.""" for e in self.__engines.values(): - e.dispose() + if hasattr(e, 'dispose'): + e.dispose() + +class SchemaVisitor(visitors.ClauseVisitor): + """Define the visiting for ``SchemaItem`` objects.""" + + __traverse_options__ = {'schema_visitor':True} + + +class DDL(object): + """A literal DDL statement. + + Specifies literal SQL DDL to be executed by the database. DDL objects can + be attached to ``Tables`` or ``MetaData`` instances, conditionally + executing SQL as part of the DDL lifecycle of those schema items. Basic + templating support allows a single DDL instance to handle repetitive tasks + for multiple tables. + + Examples:: + + tbl = Table('users', metadata, Column('uid', Integer)) # ... 
+ DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl) + + spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb') + spow.execute_at('after-create', tbl) + + drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') + connection.execute(drop_spow) + """ + + def __init__(self, statement, on=None, context=None, bind=None): + """Create a DDL statement. + + statement + A string or unicode string to be executed. Statements will be + processed with Python's string formatting operator. See the + ``context`` argument and the ``execute_at`` method. + + A literal '%' in a statement must be escaped as '%%'. + + SQL bind parameters are not available in DDL statements. + + on + Optional filtering criteria. May be a string or a callable + predicate. If a string, it will be compared to the name of the + executing database dialect:: + + DDL('something', on='postgres') + + If a callable, it will be invoked with three positional arguments: + + event + The name of the event that has triggered this DDL, such as + 'after-create' Will be None if the DDL is executed explicitly. + + schema_item + A SchemaItem instance, such as ``Table`` or ``MetaData``. May be + None if the DDL is executed explicitly. + + connection + The ``Connection`` being used for DDL execution + + If the callable returns a true value, the DDL statement will be + executed. + + context + Optional dictionary, defaults to None. These values will be + available for use in string substitutions on the DDL statement. + + bind + Optional. A ``Connectable``, used by default when ``execute()`` + is invoked without a bind argument. + """ + + if not isinstance(statement, basestring): + raise exceptions.ArgumentError( + "Expected a string or unicode SQL statement, got '%r'" % + statement) + if (on is not None and + (not isinstance(on, basestring) and not callable(on))): + raise exceptions.ArgumentError( + "Expected the name of a database dialect or a callable for " + "'on' criteria, got type '%s'." % type(on).__name__) + + self.statement = statement + self.on = on + self.context = context or {} + self._bind = bind + + def execute(self, bind=None, schema_item=None): + """Execute this DDL immediately. + + Executes the DDL statement in isolation using the supplied + ``Connectable`` or ``Connectable`` assigned to the ``.bind`` property, + if not supplied. If the DDL has a conditional ``on`` criteria, it + will be invoked with None as the event. + + bind + Optional, an ``Engine`` or ``Connection``. If not supplied, a + valid ``Connectable`` must be present in the ``.bind`` property. + + schema_item + Optional, defaults to None. Will be passed to the ``on`` callable + criteria, if any, and may provide string expansion data for the + statement. See ``execute_at`` for more information. + """ - def _get_engine(self, raiseerr=False): - if hasattr(self.context, '_engine'): - return self.context._engine + if bind is None: + bind = _bind_or_error(self) + # no SQL bind params are supported + if self._should_execute(None, schema_item, bind): + executable = expression.text(self._expand(schema_item, bind)) + return bind.execute(executable) else: - if raiseerr: - raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") - else: - return None - bind = property(_get_engine, connect) + bind.engine.logger.info("DDL execution skipped, criteria not met.") + def execute_at(self, event, schema_item): + """Link execution of this DDL to the DDL lifecycle of a SchemaItem. 
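Building on the constructor arguments above, a sketch of conditional, templated DDL; the statements, dialect name and predicate are illustrative only and not checked against any particular database::

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.schema import DDL

    metadata = MetaData()
    users = Table('users', metadata, Column('id', Integer, primary_key=True))

    # string criteria: fires only when the executing dialect is named 'mysql'
    DDL('ALTER TABLE %(table)s ENGINE=InnoDB',
        on='mysql').execute_at('after-create', users)

    # callable criteria: invoked with (event, schema_item, connection)
    def after_create_only(event, schema_item, connection):
        return event == 'after-create'

    # context keys feed the %(...)s substitutions alongside the table name
    DDL('GRANT SELECT ON %(table)s TO %(role)s',
        on=after_create_only,
        context={'role': 'reporting'}).execute_at('after-create', users)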
-class SchemaVisitor(sql.ClauseVisitor): - """Define the visiting for ``SchemaItem`` objects.""" + Links this ``DDL`` to a ``Table`` or ``MetaData`` instance, executing + it when that schema item is created or dropped. The DDL statement + will be executed using the same Connection and transactional context + as the Table create/drop itself. The ``.bind`` property of this + statement is ignored. - __traverse_options__ = {'schema_visitor':True} + event + One of the events defined in the schema item's ``.ddl_events``; + e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop' + + schema_item + A Table or MetaData instance + + When operating on Table events, the following additional ``statement`` + string substitions are available:: + + %(table)s - the Table name, with any required quoting applied + %(schema)s - the schema name, with any required quoting applied + %(fullname)s - the Table name including schema, quoted if needed + + The DDL's ``context``, if any, will be combined with the standard + substutions noted above. Keys present in the context will override + the standard substitutions. + + A DDL instance can be linked to any number of schema items. The + statement subsitution support allows for DDL instances to be used in a + template fashion. + + ``execute_at`` builds on the ``append_ddl_listener`` interface of + MetaDta and Table objects. + + Caveat: Creating or dropping a Table in isolation will also trigger + any DDL set to ``execute_at`` that Table's MetaData. This may change + in a future release. + """ + + if not hasattr(schema_item, 'ddl_listeners'): + raise exceptions.ArgumentError( + "%s does not support DDL events" % type(schema_item).__name__) + if event not in schema_item.ddl_events: + raise exceptions.ArgumentError( + "Unknown event, expected one of (%s), got '%r'" % + (', '.join(schema_item.ddl_events), event)) + schema_item.ddl_listeners[event].append(self) + return self + + def bind(self): + """An Engine or Connection to which this DDL is bound. + + This property may be assigned an ``Engine`` or ``Connection``, or + assigned a string or URL to automatically create a basic ``Engine`` + for this bind with ``create_engine()``. 
+ """ + return self._bind + + def _bind_to(self, bind): + """Bind this MetaData to an Engine, Connection, string or URL.""" + + global URL + if URL is None: + from sqlalchemy.engine.url import URL + + if isinstance(bind, (basestring, URL)): + from sqlalchemy import create_engine + self._bind = create_engine(bind) + else: + self._bind = bind + bind = property(bind, _bind_to) + + def __call__(self, event, schema_item, bind): + """Execute the DDL as a ddl_listener.""" + + if self._should_execute(event, schema_item, bind): + statement = expression.text(self._expand(schema_item, bind)) + return bind.execute(statement) + + def _expand(self, schema_item, bind): + return self.statement % self._prepare_context(schema_item, bind) + + def _should_execute(self, event, schema_item, bind): + if self.on is None: + return True + elif isinstance(self.on, basestring): + return self.on == bind.engine.name + else: + return self.on(event, schema_item, bind) + + def _prepare_context(self, schema_item, bind): + # table events can substitute table and schema name + if isinstance(schema_item, Table): + context = self.context.copy() + + preparer = bind.dialect.identifier_preparer + path = preparer.format_table_seq(schema_item) + if len(path) == 1: + table, schema = path[0], '' + else: + table, schema = path[-1], path[0] + + context.setdefault('table', table) + context.setdefault('schema', schema) + context.setdefault('fullname', preparer.format_table(schema_item)) + return context + else: + return self.context + + def __repr__(self): + return '<%s@%s; %s>' % ( + type(self).__name__, id(self), + ', '.join([repr(self.statement)] + + ['%s=%r' % (key, getattr(self, key)) + for key in ('on', 'context') + if getattr(self, key)])) + + +def _bind_or_error(schemaitem): + bind = schemaitem.bind + if not bind: + name = schemaitem.__class__.__name__ + label = getattr(schemaitem, 'fullname', + getattr(schemaitem, 'name', None)) + if label: + item = '%s %r' % (name, label) + else: + item = name + if isinstance(schemaitem, (MetaData, DDL)): + bindable = "the %s's .bind" % name + else: + bindable = "this %s's .metadata.bind" % name + + msg = ('The %s is not bound to an Engine or Connection. ' + 'Execution can not proceed without a database to execute ' + 'against. Either execute with an explicit connection or ' + 'assign %s to enable implicit execution.') % (item, bindable) + raise exceptions.UnboundExecutionError(msg) + return bind diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py deleted file mode 100644 index 01588e92da..0000000000 --- a/lib/sqlalchemy/sql.py +++ /dev/null @@ -1,3406 +0,0 @@ -# sql.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Define the base components of SQL expression trees. - -All components are derived from a common base class [sqlalchemy.sql#ClauseElement]. -Common behaviors are organized based on class hierarchies, in some cases -via mixins. - -All object construction from this package occurs via functions which in some -cases will construct composite ``ClauseElement`` structures together, and -in other cases simply return a single ``ClauseElement`` constructed directly. -The function interface affords a more "DSL-ish" feel to constructing SQL expressions -and also allows future class reorganizations. 
- -Even though classes are not constructed directly from the outside, most -classes which have additional public methods are considered to be public (i.e. have no leading underscore). -Other classes which are "semi-public" are marked with a single leading -underscore; these classes usually have few or no public methods and -are less guaranteed to stay the same in future releases. - -""" - -from sqlalchemy import util, exceptions -from sqlalchemy import types as sqltypes -import re, operator - -__all__ = ['Alias', 'ClauseElement', 'ClauseParameters', - 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', - 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', - 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', - 'between', 'bindparam', 'case', 'cast', 'column', 'delete', - 'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', - 'insert', 'intersect', 'intersect_all', 'join', 'literal', - 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select', - 'subquery', 'table', 'text', 'union', 'union_all', 'update',] - -BIND_PARAMS = re.compile(r'(?_. The - "c" collection of the resulting ``Select`` object will use these - names as well for targeting column members. - - distinct=False - when ``True``, applies a ``DISTINCT`` qualifier to the - columns clause of the resulting statement. - - for_update=False - when ``True``, applies ``FOR UPDATE`` to the end of the - resulting statement. Certain database dialects also - support alternate values for this parameter, for example - mysql supports "read" which translates to ``LOCK IN SHARE MODE``, - and oracle supports "nowait" which translates to - ``FOR UPDATE NOWAIT``. - - bind=None - an ``Engine`` or ``Connection`` instance to which the resulting ``Select`` - object will be bound. The ``Select`` object will otherwise - automatically bind to whatever ``Connectable`` instances can be located - within its contained ``ClauseElement`` members. - - limit=None - a numerical value which usually compiles to a ``LIMIT`` expression - in the resulting select. Databases that don't support ``LIMIT`` - will attempt to provide similar functionality. - - offset=None - a numerical value which usually compiles to an ``OFFSET`` expression - in the resulting select. Databases that don't support ``OFFSET`` - will attempt to provide similar functionality. - - scalar=False - deprecated. use select(...).as_scalar() to create a "scalar column" - proxy for an existing Select object. - - correlate=True - indicates that this ``Select`` object should have its contained - ``FromClause`` elements "correlated" to an enclosing ``Select`` - object. This means that any ``ClauseElement`` instance within - the "froms" collection of this ``Select`` which is also present - in the "froms" collection of an enclosing select will not be - rendered in the ``FROM`` clause of this select statement. - - """ - scalar = kwargs.pop('scalar', False) - s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) - if scalar: - return s.as_scalar() - else: - return s - -def subquery(alias, *args, **kwargs): - """Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select]. - - name - alias name - - \*args, \**kwargs - all other arguments are delivered to the [sqlalchemy.sql#select()] function. - - """ - - return Select(*args, **kwargs).alias(alias) - -def insert(table, values = None, **kwargs): - """Return an [sqlalchemy.sql#Insert] clause element. 
- - Similar functionality is available via the ``insert()`` - method on [sqlalchemy.schema#Table]. - - table - The table to be inserted into. - - values - A dictionary which specifies the column specifications of the - ``INSERT``, and is optional. If left as None, the column - specifications are determined from the bind parameters used - during the compile phase of the ``INSERT`` statement. If the - bind parameters also are None during the compile phase, then the - column specifications will be generated from the full list of - table columns. - - If both `values` and compile-time bind parameters are present, the - compile-time bind parameters override the information specified - within `values` on a per-key basis. - - The keys within `values` can be either ``Column`` objects or their - string identifiers. Each key may reference one of: - - * a literal data value (i.e. string, number, etc.); - * a Column object; - * a SELECT statement. - - If a ``SELECT`` statement is specified which references this - ``INSERT`` statement's table, the statement will be correlated - against the ``INSERT`` statement. - """ - - return Insert(table, values, **kwargs) - -def update(table, whereclause = None, values = None, **kwargs): - """Return an [sqlalchemy.sql#Update] clause element. - - Similar functionality is available via the ``update()`` - method on [sqlalchemy.schema#Table]. - - table - The table to be updated. - - whereclause - A ``ClauseElement`` describing the ``WHERE`` condition of the - ``UPDATE`` statement. - - values - A dictionary which specifies the ``SET`` conditions of the - ``UPDATE``, and is optional. If left as None, the ``SET`` - conditions are determined from the bind parameters used during - the compile phase of the ``UPDATE`` statement. If the bind - parameters also are None during the compile phase, then the - ``SET`` conditions will be generated from the full list of table - columns. - - If both `values` and compile-time bind parameters are present, the - compile-time bind parameters override the information specified - within `values` on a per-key basis. - - The keys within `values` can be either ``Column`` objects or their - string identifiers. Each key may reference one of: - - * a literal data value (i.e. string, number, etc.); - * a Column object; - * a SELECT statement. - - If a ``SELECT`` statement is specified which references this - ``UPDATE`` statement's table, the statement will be correlated - against the ``UPDATE`` statement. - """ - - return Update(table, whereclause, values, **kwargs) - -def delete(table, whereclause = None, **kwargs): - """Return a [sqlalchemy.sql#Delete] clause element. - - Similar functionality is available via the ``delete()`` - method on [sqlalchemy.schema#Table]. - - table - The table to be updated. - - whereclause - A ``ClauseElement`` describing the ``WHERE`` condition of the - ``UPDATE`` statement. - - """ - - return Delete(table, whereclause, **kwargs) - -def and_(*clauses): - """Join a list of clauses together using the ``AND`` operator. - - The ``&`` operator is also overloaded on all [sqlalchemy.sql#_CompareMixin] - subclasses to produce the same result. - """ - if len(clauses) == 1: - return clauses[0] - return ClauseList(operator=operator.and_, *clauses) - -def or_(*clauses): - """Join a list of clauses together using the ``OR`` operator. - - The ``|`` operator is also overloaded on all [sqlalchemy.sql#_CompareMixin] - subclasses to produce the same result. 
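The ``values`` dictionary accepted by insert() and update() can be keyed either by Column objects or their string names, as described above; a brief sketch against the illustrative ``users`` table::

    from sqlalchemy import insert, update

    ins = insert(users, values={'name': 'ed'})
    upd = update(users, users.c.id == 5, values={users.c.name: 'edward'})
    # executing either statement against a bound engine emits the
    # corresponding INSERT / UPDATE with bind parameters for the values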
- """ - - if len(clauses) == 1: - return clauses[0] - return ClauseList(operator=operator.or_, *clauses) - -def not_(clause): - """Return a negation of the given clause, i.e. ``NOT(clause)``. - - The ``~`` operator is also overloaded on all [sqlalchemy.sql#_CompareMixin] - subclasses to produce the same result. - """ - - return operator.inv(clause) - -def distinct(expr): - """return a ``DISTINCT`` clause.""" - - return _UnaryExpression(expr, operator="DISTINCT") - -def between(ctest, cleft, cright): - """Return a ``BETWEEN`` predicate clause. - - Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``. - - The ``between()`` method on all [sqlalchemy.sql#_CompareMixin] subclasses - provides similar functionality. - """ - - ctest = _literal_as_binds(ctest) - return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op) - - -def case(whens, value=None, else_=None): - """Produce a ``CASE`` statement. - - whens - A sequence of pairs to be translated into "when / then" clauses. - - value - Optional for simple case statements. - - else\_ - Optional as well, for case defaults. - - """ - - whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None) for (c,r) in whens] - if not else_ is None: - whenlist.append(ClauseList('ELSE', else_, operator=None)) - if len(whenlist): - type = list(whenlist[-1])[-1].type - else: - type = None - cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END']) - return cc - -def cast(clause, totype, **kwargs): - """Return a ``CAST`` function. - - Equivalent of SQL ``CAST(clause AS totype)``. - - Use with a [sqlalchemy.types#TypeEngine] subclass, i.e:: - - cast(table.c.unit_price * table.c.qty, Numeric(10,4)) - - or:: - - cast(table.c.timestamp, DATE) - """ - - return _Cast(clause, totype, **kwargs) - -def extract(field, expr): - """Return the clause ``extract(field FROM expr)``.""" - - expr = _BinaryExpression(text(field), expr, Operators.from_) - return func.extract(expr) - -def exists(*args, **kwargs): - """Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql#Select] object. - - The resulting [sqlalchemy.sql#_Exists] object can be executed by itself - or used as a subquery within an enclosing select. - - \*args, \**kwargs - all arguments are sent directly to the [sqlalchemy.sql#select()] function - to produce a ``SELECT`` statement. - - """ - - return _Exists(*args, **kwargs) - -def union(*selects, **kwargs): - """Return a ``UNION`` of multiple selectables. - - The returned object is an instance of [sqlalchemy.sql#CompoundSelect]. - - A similar ``union()`` method is available on all [sqlalchemy.sql#FromClause] - subclasses. - - \*selects - a list of [sqlalchemy.sql#Select] instances. - - \**kwargs - available keyword arguments are the same as those of [sqlalchemy.sql#select()]. - - """ - - return _compound_select('UNION', *selects, **kwargs) - -def union_all(*selects, **kwargs): - """Return a ``UNION ALL`` of multiple selectables. - - The returned object is an instance of [sqlalchemy.sql#CompoundSelect]. - - A similar ``union_all()`` method is available on all [sqlalchemy.sql#FromClause] - subclasses. - - \*selects - a list of [sqlalchemy.sql#Select] instances. - - \**kwargs - available keyword arguments are the same as those of [sqlalchemy.sql#select()]. 
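A short sketch of the conditional and type-coercion constructs described above (column names again illustrative)::

    from sqlalchemy import case, cast, Numeric

    rank = case([(users.c.name == 'ed', 1),
                 (users.c.name == 'jack', 2)],
                else_=0)
    price = cast(users.c.id * 1.5, Numeric(10, 4))   # CAST(... AS NUMERIC(10, 4))
    window = users.c.id.between(3, 7)                # id BETWEEN 3 AND 7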
- - """ - return _compound_select('UNION ALL', *selects, **kwargs) - -def except_(*selects, **kwargs): - """Return an ``EXCEPT`` of multiple selectables. - - The returned object is an instance of [sqlalchemy.sql#CompoundSelect]. - - \*selects - a list of [sqlalchemy.sql#Select] instances. - - \**kwargs - available keyword arguments are the same as those of [sqlalchemy.sql#select()]. - - """ - return _compound_select('EXCEPT', *selects, **kwargs) - -def except_all(*selects, **kwargs): - """Return an ``EXCEPT ALL`` of multiple selectables. - - The returned object is an instance of [sqlalchemy.sql#CompoundSelect]. - - \*selects - a list of [sqlalchemy.sql#Select] instances. - - \**kwargs - available keyword arguments are the same as those of [sqlalchemy.sql#select()]. - - """ - return _compound_select('EXCEPT ALL', *selects, **kwargs) - -def intersect(*selects, **kwargs): - """Return an ``INTERSECT`` of multiple selectables. - - The returned object is an instance of [sqlalchemy.sql#CompoundSelect]. - - \*selects - a list of [sqlalchemy.sql#Select] instances. - - \**kwargs - available keyword arguments are the same as those of [sqlalchemy.sql#select()]. - - """ - return _compound_select('INTERSECT', *selects, **kwargs) - -def intersect_all(*selects, **kwargs): - """Return an ``INTERSECT ALL`` of multiple selectables. - - The returned object is an instance of [sqlalchemy.sql#CompoundSelect]. - - \*selects - a list of [sqlalchemy.sql#Select] instances. - - \**kwargs - available keyword arguments are the same as those of [sqlalchemy.sql#select()]. - - """ - return _compound_select('INTERSECT ALL', *selects, **kwargs) - -def alias(selectable, alias=None): - """Return an [sqlalchemy.sql#Alias] object. - - An ``Alias`` represents any [sqlalchemy.sql#FromClause] with - an alternate name assigned within SQL, typically using the ``AS`` - clause when generated, e.g. ``SELECT * FROM table AS aliasname``. - - Similar functionality is available via the ``alias()`` method - available on all ``FromClause`` subclasses. - - selectable - any ``FromClause`` subclass, such as a table, select statement, etc.. - - alias - string name to be assigned as the alias. If ``None``, a random - name will be generated. - - """ - - return Alias(selectable, alias=alias) - - -def literal(value, type_=None): - """Return a literal clause, bound to a bind parameter. - - Literal clauses are created automatically when non- - ``ClauseElement`` objects (such as strings, ints, dates, etc.) are used in - a comparison operation with a [sqlalchemy.sql#_CompareMixin] - subclass, such as a ``Column`` object. Use this function - to force the generation of a literal clause, which will - be created as a [sqlalchemy.sql#_BindParamClause] with a bound - value. - - value - the value to be bound. can be any Python object supported by - the underlying DBAPI, or is translatable via the given type - argument. - - type\_ - an optional [sqlalchemy.types#TypeEngine] which will provide - bind-parameter translation for this literal. - - """ - - return _BindParamClause('literal', value, type_=type_, unique=True) - -def label(name, obj): - """Return a [sqlalchemy.sql#_Label] object for the given [sqlalchemy.sql#ColumnElement]. - - A label changes the name of an element in the columns clause - of a ``SELECT`` statement, typically via the ``AS`` SQL keyword. - - This functionality is more conveniently available via - the ``label()`` method on ``ColumnElement``. - - name - label name - - obj - a ``ColumnElement``. 
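The aliasing and literal helpers above can be combined freely; for example, assuming the same ``users`` table::

    from sqlalchemy import alias, literal, union, select

    u2 = alias(users, 'u2')                     # renders "users AS u2"
    named = literal('ed').label('fixed_name')   # bound literal with an AS label
    both = union(select([users.c.name]), select([u2.c.name]))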
- - """ - - return _Label(name, obj) - -def column(text, type_=None): - """Return a textual column clause, as would be in the columns - clause of a ``SELECT`` statement. - - The object returned is an instance of [sqlalchemy.sql#_ColumnClause], - which represents the "syntactical" portion of the schema-level - [sqlalchemy.schema#Column] object. - - text - the name of the column. Quoting rules will be applied to - the clause like any other column name. For textual column - constructs that are not to be quoted, use the [sqlalchemy.sql#literal_column()] - function. - - type\_ - an optional [sqlalchemy.types#TypeEngine] object which will provide - result-set translation for this column. - - """ - - return _ColumnClause(text, type_=type_) - -def literal_column(text, type_=None): - """Return a textual column clause, as would be in the columns - clause of a ``SELECT`` statement. - - The object returned is an instance of [sqlalchemy.sql#_ColumnClause], - which represents the "syntactical" portion of the schema-level - [sqlalchemy.schema#Column] object. - - - text - the name of the column. Quoting rules will not be applied - to the column. For textual column - constructs that should be quoted like any other column - construct, use the [sqlalchemy.sql#column()] - function. - - type - an optional [sqlalchemy.types#TypeEngine] object which will provide - result-set translation for this column. - - """ - - return _ColumnClause(text, type_=type_, is_literal=True) - -def table(name, *columns): - """Return a [sqlalchemy.sql#Table] object. - - This is a primitive version of the [sqlalchemy.schema#Table] object, which - is a subclass of this object. - """ - - return TableClause(name, *columns) - -def bindparam(key, value=None, type_=None, shortname=None, unique=False): - """Create a bind parameter clause with the given key. - - value - a default value for this bind parameter. a bindparam with a value - is called a ``value-based bindparam``. - - shortname - an ``alias`` for this bind parameter. usually used to alias the ``key`` and - ``label`` of a column, i.e. ``somecolname`` and ``sometable_somecolname`` - - type - a sqlalchemy.types.TypeEngine object indicating the type of this bind param, will - invoke type-specific bind parameter processing - - unique - if True, bind params sharing the same name will have their underlying ``key`` modified - to a uniquely generated name. mostly useful with value-based bind params. - - """ - - if isinstance(key, _ColumnClause): - return _BindParamClause(key.name, value, type_=key.type, shortname=shortname, unique=unique) - else: - return _BindParamClause(key, value, type_=type_, shortname=shortname, unique=unique) - -def outparam(key, type_=None): - """create an 'OUT' parameter for usage in functions (stored procedures), for databases - whith support them. - - The ``outparam`` can be used like a regular function parameter. The "output" value will - be available from the [sqlalchemy.engine#ResultProxy] object via its ``out_parameters`` - attribute, which returns a dictionary containing the values. - """ - - return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True) - -def text(text, bind=None, *args, **kwargs): - """Create literal text to be inserted into a query. - - When constructing a query from a ``select()``, ``update()``, - ``insert()`` or ``delete()``, using plain strings for argument - values will usually result in text objects being created - automatically. 
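A sketch of an explicitly named bind parameter, per the bindparam() description above; the value is supplied at execution time::

    from sqlalchemy import bindparam, select

    stmt = select([users], users.c.name == bindparam('nm'))
    # e.g. engine.execute(stmt, nm='ed') fills in the :nm parameter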
Use this function when creating textual clauses - outside of other ``ClauseElement`` objects, or optionally wherever - plain text is to be used. - - text - The text of the SQL statement to be created. use ``:`` - to specify bind parameters; they will be compiled to their - engine-specific format. - - bind - An optional connection or engine to be used for this text query. - - bindparams - A list of ``bindparam()`` instances which can be used to define - the types and/or initial values for the bind parameters within - the textual statement; the keynames of the bindparams must match - those within the text of the statement. The types will be used - for pre-processing on bind values. - - typemap - A dictionary mapping the names of columns represented in the - ``SELECT`` clause of the textual statement to type objects, - which will be used to perform post-processing on columns within - the result set (for textual statements that produce result - sets). - - """ - - return _TextClause(text, bind=bind, *args, **kwargs) - -def null(): - """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement.""" - - return _Null() - -class _FunctionGenerator(object): - """Generate ``_Function`` objects based on getattr calls.""" - - def __init__(self, **opts): - self.__names = [] - self.opts = opts - - def __getattr__(self, name): - if name[-1] == '_': - name = name[0:-1] - f = _FunctionGenerator(**self.opts) - f.__names = list(self.__names) + [name] - return f - - def __call__(self, *c, **kwargs): - o = self.opts.copy() - o.update(kwargs) - return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) - -func = _FunctionGenerator() - -# TODO: use UnaryExpression for this instead ? -modifier = _FunctionGenerator(group=False) - - -def _compound_select(keyword, *selects, **kwargs): - return CompoundSelect(keyword, *selects, **kwargs) - -def _is_literal(element): - return not isinstance(element, ClauseElement) - -def _literal_as_text(element): - if isinstance(element, Operators): - return element.expression_element() - elif _is_literal(element): - return _TextClause(unicode(element)) - else: - return element - -def _literal_as_column(element): - if isinstance(element, Operators): - return element.clause_element() - elif _is_literal(element): - return literal_column(str(element)) - else: - return element - -def _literal_as_binds(element, name='literal', type_=None): - if isinstance(element, Operators): - return element.expression_element() - elif _is_literal(element): - if element is None: - return null() - else: - return _BindParamClause(name, element, shortname=name, type_=type_, unique=True) - else: - return element - -def _selectable(element): - if hasattr(element, '__selectable__'): - return element.__selectable__() - elif isinstance(element, Selectable): - return element - else: - raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) - -def is_column(col): - return isinstance(col, ColumnElement) - -class ClauseParameters(object): - """Represent a dictionary/iterator of bind parameter key names/values. - - Tracks the original [sqlalchemy.sql#_BindParamClause] objects as well as the - keys/position of each parameter, and can return parameters as a - dictionary or a list. Will process parameter values according to - the ``TypeEngine`` objects present in the ``_BindParamClause`` instances. 
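The ``func`` generator described above creates _Function elements from arbitrary attribute access; ``my_schema.my_proc`` below is a hypothetical name used only to show the dotted/package form::

    from sqlalchemy import func, select

    row_count = select([func.count(users.c.id).label('tbl_row_count')])
    stamp = func.current_timestamp()    # unknown names are passed through as-is
    proc = func.my_schema.my_proc(5)    # dotted access becomes packagenames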
- """ - - def __init__(self, dialect, positional=None): - super(ClauseParameters, self).__init__() - self.dialect = dialect - self.__binds = {} - self.positional = positional or [] - - def get_parameter(self, key): - return self.__binds[key] - - def set_parameter(self, bindparam, value, name): - self.__binds[name] = [bindparam, name, value] - - def get_original(self, key): - return self.__binds[key][2] - - def get_type(self, key): - return self.__binds[key][0].type - - def get_processed(self, key): - (bind, name, value) = self.__binds[key] - return bind.typeprocess(value, self.dialect) - - def keys(self): - return self.__binds.keys() - - def __iter__(self): - return iter(self.keys()) - - def __getitem__(self, key): - return self.get_processed(key) - - def __contains__(self, key): - return key in self.__binds - - def set_value(self, key, value): - self.__binds[key][2] = value - - def get_original_dict(self): - return dict([(name, value) for (b, name, value) in self.__binds.values()]) - - def get_raw_list(self): - return [self.get_processed(key) for key in self.positional] - - def get_raw_dict(self, encode_keys=False): - if encode_keys: - return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()]) - else: - return dict([(key, self.get_processed(key)) for key in self.keys()]) - - def __repr__(self): - return self.__class__.__name__ + ":" + repr(self.get_original_dict()) - -class ClauseVisitor(object): - """A class that knows how to traverse and visit - ``ClauseElements``. - - Calls visit_XXX() methods dynamically generated for each particualr - ``ClauseElement`` subclass encountered. Traversal of a - hierarchy of ``ClauseElements`` is achieved via the - ``traverse()`` method, which is passed the lead - ``ClauseElement``. - - By default, ``ClauseVisitor`` traverses all elements - fully. Options can be specified at the class level via the - ``__traverse_options__`` dictionary which will be passed - to the ``get_children()`` method of each ``ClauseElement``; - these options can indicate modifications to the set of - elements returned, such as to not return column collections - (column_collections=False) or to return Schema-level items - (schema_visitor=True). - - ``ClauseVisitor`` also supports a simultaneous copy-and-traverse - operation, which will produce a copy of a given ``ClauseElement`` - structure while at the same time allowing ``ClauseVisitor`` subclasses - to modify the new structure in-place. - - """ - __traverse_options__ = {} - - def traverse_single(self, obj, **kwargs): - meth = getattr(self, "visit_%s" % obj.__visit_name__, None) - if meth: - return meth(obj, **kwargs) - - def traverse(self, obj, stop_on=None, clone=False): - if clone: - obj = obj._clone() - - v = self - visitors = [] - while v is not None: - visitors.append(v) - v = getattr(v, '_next', None) - - def _trav(obj): - if stop_on is not None and obj in stop_on: - return - if clone: - obj._copy_internals() - for c in obj.get_children(**self.__traverse_options__): - _trav(c) - - for v in visitors: - meth = getattr(v, "visit_%s" % obj.__visit_name__, None) - if meth: - meth(obj) - _trav(obj) - return obj - - def chain(self, visitor): - """'chain' an additional ClauseVisitor onto this ClauseVisitor. 
- - the chained visitor will receive all visit events after this one.""" - tail = self - while getattr(tail, '_next', None) is not None: - tail = tail._next - tail._next = visitor - return self - -class NoColumnVisitor(ClauseVisitor): - """a ClauseVisitor that will not traverse the exported Column - collections on Table, Alias, Select, and CompoundSelect objects - (i.e. their 'columns' or 'c' attribute). - - this is useful because most traversals don't need those columns, or - in the case of ANSICompiler it traverses them explicitly; so - skipping their traversal here greatly cuts down on method call overhead. - """ - - __traverse_options__ = {'column_collections':False} - - -class _FigureVisitName(type): - def __init__(cls, clsname, bases, dict): - if not '__visit_name__' in cls.__dict__: - m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname) - x = m.group(1) - x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x) - cls.__visit_name__ = x.lower() - super(_FigureVisitName, cls).__init__(clsname, bases, dict) - -class ClauseElement(object): - """Base class for elements of a programmatically constructed SQL - expression. - """ - __metaclass__ = _FigureVisitName - - def _clone(self): - """create a shallow copy of this ClauseElement. - - This method may be used by a generative API. - Its also used as part of the "deep" copy afforded - by a traversal that combines the _copy_internals() - method.""" - c = self.__class__.__new__(self.__class__) - c.__dict__ = self.__dict__.copy() - return c - - def _get_from_objects(self, **modifiers): - """Return objects represented in this ``ClauseElement`` that - should be added to the ``FROM`` list of a query, when this - ``ClauseElement`` is placed in the column clause of a - ``Select`` statement. - """ - - raise NotImplementedError(repr(self)) - - def _hide_froms(self, **modifiers): - """Return a list of ``FROM`` clause elements which this - ``ClauseElement`` replaces. - """ - - return [] - - def compare(self, other): - """Compare this ClauseElement to the given ClauseElement. - - Subclasses should override the default behavior, which is a - straight identity comparison. - """ - - return self is other - - def _copy_internals(self): - """reassign internal elements to be clones of themselves. - - called during a copy-and-traverse operation on newly - shallow-copied elements to create a deep copy.""" - - pass - - def get_children(self, **kwargs): - """return immediate child elements of this ``ClauseElement``. - - this is used for visit traversal. - - \**kwargs may contain flags that change the collection - that is returned, for example to return a subset of items - in order to cut down on larger traversals, or to return - child items from a different context (such as schema-level - collections instead of clause-level).""" - return [] - - def self_group(self, against=None): - return self - - def supports_execution(self): - """Return True if this clause element represents a complete - executable statement. - """ - - return False - - def _find_engine(self): - """Default strategy for locating an engine within the clause element. - - Relies upon a local engine property, or looks in the *from* - objects which ultimately have to contain Tables or - TableClauses. 
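A sketch of the visit_<name> dispatch and traverse() mechanics of ClauseVisitor described a little earlier, written against the 0.4-era module shown in this diff::

    from sqlalchemy import sql, bindparam

    class FindBinds(sql.ClauseVisitor):
        # collect _BindParamClause elements; their __visit_name__ is 'bindparam'
        def __init__(self):
            self.binds = []
        def visit_bindparam(self, bind):
            self.binds.append(bind)

    v = FindBinds()
    v.traverse(users.c.name == bindparam('nm'))
    # v.binds now holds the single bind parameter found in the expression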
- """ - - try: - if self._bind is not None: - return self._bind - except AttributeError: - pass - for f in self._get_from_objects(): - if f is self: - continue - engine = f.bind - if engine is not None: - return engine - else: - return None - - bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""") - - def execute(self, *multiparams, **params): - """Compile and execute this ``ClauseElement``.""" - - if len(multiparams): - compile_params = multiparams[0] - else: - compile_params = params - return self.compile(bind=self.bind, parameters=compile_params).execute(*multiparams, **params) - - def scalar(self, *multiparams, **params): - """Compile and execute this ``ClauseElement``, returning the - result's scalar representation. - """ - - return self.execute(*multiparams, **params).scalar() - - def compile(self, bind=None, parameters=None, compiler=None, dialect=None): - """Compile this SQL expression. - - Uses the given ``Compiler``, or the given ``AbstractDialect`` - or ``Engine`` to create a ``Compiler``. If no `compiler` - arguments are given, tries to use the underlying ``Engine`` this - ``ClauseElement`` is bound to to create a ``Compiler``, if any. - - Finally, if there is no bound ``Engine``, uses an - ``ANSIDialect`` to create a default ``Compiler``. - - `parameters` is a dictionary representing the default bind - parameters to be used with the statement. If `parameters` is - a list, it is assumed to be a list of dictionaries and the - first dictionary in the list is used with which to compile - against. - - The bind parameters can in some cases determine the output of - the compilation, such as for ``UPDATE`` and ``INSERT`` - statements the bind parameters that are present determine the - ``SET`` and ``VALUES`` clause of those statements. 
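Compilation and execution as described above, sketched against an in-memory SQLite engine (the URL is illustrative)::

    from sqlalchemy import create_engine

    engine = create_engine('sqlite://')
    metadata.create_all(engine)

    stmt = select([users.c.name], users.c.id == 5)
    compiled = stmt.compile(bind=engine)   # a Compiler for the engine's dialect
    result = engine.execute(stmt)          # or stmt.execute() when bound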
- """ - - if isinstance(parameters, (list, tuple)): - parameters = parameters[0] - - if compiler is None: - if dialect is not None: - compiler = dialect.compiler(self, parameters) - elif bind is not None: - compiler = bind.compiler(self, parameters) - elif self.bind is not None: - compiler = self.bind.compiler(self, parameters) - - if compiler is None: - import sqlalchemy.ansisql as ansisql - compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters) - compiler.compile() - return compiler - - def __str__(self): - return unicode(self.compile()).encode('ascii', 'backslashreplace') - - def __and__(self, other): - return and_(self, other) - - def __or__(self, other): - return or_(self, other) - - def __invert__(self): - return self._negate() - - def _negate(self): - if hasattr(self, 'negation_clause'): - return self.negation_clause - else: - return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None) - - -class Operators(object): - def from_(): - raise NotImplementedError() - from_ = staticmethod(from_) - - def as_(): - raise NotImplementedError() - as_ = staticmethod(as_) - - def exists(): - raise NotImplementedError() - exists = staticmethod(exists) - - def is_(): - raise NotImplementedError() - is_ = staticmethod(is_) - - def isnot(): - raise NotImplementedError() - isnot = staticmethod(isnot) - - def __and__(self, other): - return self.operate(operator.and_, other) - - def __or__(self, other): - return self.operate(operator.or_, other) - - def __invert__(self): - return self.operate(operator.inv) - - def clause_element(self): - raise NotImplementedError() - - def operate(self, op, *other, **kwargs): - raise NotImplementedError() - - def reverse_operate(self, op, *other, **kwargs): - raise NotImplementedError() - -class ColumnOperators(Operators): - """defines comparison and math operations""" - - def like_op(a, b): - return a.like(b) - like_op = staticmethod(like_op) - - def notlike_op(a, b): - raise NotImplementedError() - notlike_op = staticmethod(notlike_op) - - def ilike_op(a, b): - return a.ilike(b) - ilike_op = staticmethod(ilike_op) - - def notilike_op(a, b): - raise NotImplementedError() - notilike_op = staticmethod(notilike_op) - - def between_op(a, b): - return a.between(b) - between_op = staticmethod(between_op) - - def in_op(a, b): - return a.in_(*b) - in_op = staticmethod(in_op) - - def notin_op(a, b): - raise NotImplementedError() - notin_op = staticmethod(notin_op) - - def startswith_op(a, b): - return a.startswith(b) - startswith_op = staticmethod(startswith_op) - - def endswith_op(a, b): - return a.endswith(b) - endswith_op = staticmethod(endswith_op) - - def comma_op(a, b): - raise NotImplementedError() - comma_op = staticmethod(comma_op) - - def concat_op(a, b): - return a.concat(b) - concat_op = staticmethod(concat_op) - - def __lt__(self, other): - return self.operate(operator.lt, other) - - def __le__(self, other): - return self.operate(operator.le, other) - - def __eq__(self, other): - return self.operate(operator.eq, other) - - def __ne__(self, other): - return self.operate(operator.ne, other) - - def __gt__(self, other): - return self.operate(operator.gt, other) - - def __ge__(self, other): - return self.operate(operator.ge, other) - - def concat(self, other): - return self.operate(ColumnOperators.concat_op, other) - - def like(self, other): - return self.operate(ColumnOperators.like_op, other) - - def in_(self, *other): - return self.operate(ColumnOperators.in_op, other) - - def startswith(self, other): - return 
self.operate(ColumnOperators.startswith_op, other) - - def endswith(self, other): - return self.operate(ColumnOperators.endswith_op, other) - - def __radd__(self, other): - return self.reverse_operate(operator.add, other) - - def __rsub__(self, other): - return self.reverse_operate(operator.sub, other) - - def __rmul__(self, other): - return self.reverse_operate(operator.mul, other) - - def __rdiv__(self, other): - return self.reverse_operate(operator.div, other) - - def between(self, cleft, cright): - return self.operate(Operators.between_op, (cleft, cright)) - - def __add__(self, other): - return self.operate(operator.add, other) - - def __sub__(self, other): - return self.operate(operator.sub, other) - - def __mul__(self, other): - return self.operate(operator.mul, other) - - def __div__(self, other): - return self.operate(operator.div, other) - - def __mod__(self, other): - return self.operate(operator.mod, other) - - def __truediv__(self, other): - return self.operate(operator.truediv, other) - -# precedence ordering for common operators. if an operator is not present in this list, -# it will be parenthesized when grouped against other operators -_smallest = object() -_largest = object() - -PRECEDENCE = { - Operators.from_:15, - operator.mul:7, - operator.div:7, - operator.mod:7, - operator.add:6, - operator.sub:6, - ColumnOperators.concat_op:6, - ColumnOperators.ilike_op:5, - ColumnOperators.notilike_op:5, - ColumnOperators.like_op:5, - ColumnOperators.notlike_op:5, - ColumnOperators.in_op:5, - ColumnOperators.notin_op:5, - Operators.is_:5, - Operators.isnot:5, - operator.eq:5, - operator.ne:5, - operator.gt:5, - operator.lt:5, - operator.ge:5, - operator.le:5, - ColumnOperators.between_op:5, - operator.inv:4, - operator.and_:3, - operator.or_:2, - ColumnOperators.comma_op:-1, - Operators.as_:-1, - Operators.exists:0, - _smallest: -1000, - _largest: 1000 -} - -class _CompareMixin(ColumnOperators): - """Defines comparison and math operations for ``ClauseElement`` instances.""" - - def __compare(self, op, obj, negate=None): - if obj is None or isinstance(obj, _Null): - if op == operator.eq: - return _BinaryExpression(self.expression_element(), null(), Operators.is_, negate=Operators.isnot) - elif op == operator.ne: - return _BinaryExpression(self.expression_element(), null(), Operators.isnot, negate=Operators.is_) - else: - raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") - else: - obj = self._check_literal(obj) - - - return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate) - - def __operate(self, op, obj): - obj = self._check_literal(obj) - - type_ = self._compare_type(obj) - - # TODO: generalize operator overloading like this out into the types module - if op == operator.add and isinstance(type_, (sqltypes.Concatenable)): - op = ColumnOperators.concat_op - - return _BinaryExpression(self.expression_element(), obj, op, type_=type_) - - operators = { - operator.add : (__operate,), - operator.mul : (__operate,), - operator.sub : (__operate,), - operator.div : (__operate,), - operator.mod : (__operate,), - operator.truediv : (__operate,), - operator.lt : (__compare, operator.ge), - operator.le : (__compare, operator.gt), - operator.ne : (__compare, operator.eq), - operator.gt : (__compare, operator.le), - operator.ge : (__compare, operator.lt), - operator.eq : (__compare, operator.ne), - ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op), - } - - def operate(self, op, other): - o = 
_CompareMixin.operators[op] - return o[0](self, op, other, *o[1:]) - - def reverse_operate(self, op, other): - return self._bind_param(other).operate(op, self) - - def in_(self, *other): - return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other) - - def _in_impl(self, op, negate_op, *other): - if len(other) == 0: - return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1'))) - elif len(other) == 1: - o = other[0] - if _is_literal(o) or isinstance( o, _CompareMixin): - return self.__eq__( o) #single item -> == - else: - assert isinstance(o, Selectable) - return self.__compare( op, o, negate=negate_op) #single selectable - - args = [] - for o in other: - if not _is_literal(o): - if not isinstance( o, _CompareMixin): - raise exceptions.InvalidRequestError( "in() function accepts either non-selectable values, or a single selectable: "+repr(o) ) - else: - o = self._bind_param(o) - args.append(o) - return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op) - - def startswith(self, other): - """produce the clause ``LIKE '%'``""" - - perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String) - return self.__compare(ColumnOperators.like_op, other + perc) - - def endswith(self, other): - """produce the clause ``LIKE '%'``""" - - if isinstance(other,(str,unicode)): po = '%' + other - else: - po = literal('%', type_=sqltypes.String) + other - po.type = sqltypes.to_instance(sqltypes.String) #force! - return self.__compare(ColumnOperators.like_op, po) - - def label(self, name): - """produce a column label, i.e. `` AS ``""" - return _Label(name, self, self.type) - - def distinct(self): - """produce a DISTINCT clause, i.e. ``DISTINCT ``""" - return _UnaryExpression(self, operator="DISTINCT") - - def between(self, cleft, cright): - """produce a BETWEEN clause, i.e. `` BETWEEN AND ``""" - - return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), ColumnOperators.between_op) - - def op(self, operator): - """produce a generic operator function. - - e.g. - - somecolumn.op("*")(5) - - produces - - somecolumn * 5 - - operator - a string which will be output as the infix operator - between this ``ClauseElement`` and the expression - passed to the generated function. - - """ - return lambda other: self.__operate(operator, other) - - def _bind_param(self, obj): - return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True) - - def _check_literal(self, other): - if isinstance(other, Operators): - return other.expression_element() - elif _is_literal(other): - return self._bind_param(other) - else: - return other - - def clause_element(self): - """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``.""" - return self - - def expression_element(self): - """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" - - return self - - def _compare_type(self, obj): - """Allow subclasses to override the type used in constructing - ``_BinaryExpression`` objects. - - Default return value is the type of the given object. - """ - - return obj.type - -class Selectable(ClauseElement): - """Represent a column list-holding object. - - this is the common base class of [sqlalchemy.sql#ColumnElement] - and [sqlalchemy.sql#FromClause]. 
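Because of the Operators/_CompareMixin overloads above, plain Python operators on column elements build _BinaryExpression trees rather than evaluating immediately; a couple of examples with the illustrative ``users`` table::

    expr = (users.c.id > 5) & (users.c.name.like('e%'))   # AND via __and__
    negated = ~users.c.name.in_('ed', 'jack')             # NOT (name IN (...))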
The reason ``ColumnElement`` - is marked as a "list-holding" object is so that it can be treated - similarly to ``FromClause`` in column-selection scenarios; it - contains a list of columns consisting of itself. - - """ - - columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""") - - def select(self, whereclauses = None, **params): - return select([self], whereclauses, **params) - - -class ColumnElement(Selectable, _CompareMixin): - """Represent an element that is useable within the - "column clause" portion of a ``SELECT`` statement. - - This includes columns associated with tables, aliases, - and subqueries, expressions, function calls, SQL keywords - such as ``NULL``, literals, etc. ``ColumnElement`` is the - ultimate base class for all such elements. - - ``ColumnElement`` supports the ability to be a *proxy* element, - which indicates that the ``ColumnElement`` may be associated with - a ``Selectable`` which was derived from another ``Selectable``. - An example of a "derived" ``Selectable`` is an ``Alias`` of - a ``Table``. - - a ``ColumnElement``, by subclassing the ``_CompareMixin`` mixin - class, provides the ability to generate new ``ClauseElement`` - objects using Python expressions. See the ``_CompareMixin`` - docstring for more details. - """ - - primary_key = property(lambda self:getattr(self, '_primary_key', False), - doc=\ - """Primary key flag. Indicates if this ``Column`` represents part or - whole of a primary key for its parent table. - """) - foreign_keys = property(lambda self:getattr(self, '_foreign_keys', []), - doc=\ - """Foreign key accessor. References a list of ``ForeignKey`` objects - which each represent a foreign key placed on this column's ultimate - ancestor. - """) - columns = property(lambda self:[self], - doc=\ - """Columns accessor which returns ``self``, to provide compatibility - with ``Selectable`` objects. - """) - - def _one_fkey(self): - if len(self._foreign_keys): - return list(self._foreign_keys)[0] - else: - return None - - foreign_key = property(_one_fkey) - - def _get_orig_set(self): - try: - return self.__orig_set - except AttributeError: - self.__orig_set = util.Set([self]) - return self.__orig_set - - def _set_orig_set(self, s): - if len(s) == 0: - s.add(self) - self.__orig_set = s - - orig_set = property(_get_orig_set, _set_orig_set, - doc=\ - """A Set containing TableClause-bound, non-proxied ColumnElements - for which this ColumnElement is a proxy. In all cases except - for a column proxied from a Union (i.e. CompoundSelect), this - set will be just one element. - """) - - def shares_lineage(self, othercolumn): - """Return True if the given ``ColumnElement`` has a common - ancestor to this ``ColumnElement``. - """ - - for c in self.orig_set: - if c in othercolumn.orig_set: - return True - else: - return False - - def _make_proxy(self, selectable, name=None): - """Create a new ``ColumnElement`` representing this - ``ColumnElement`` as it appears in the select list of a - descending selectable. - - The default implementation returns a ``_ColumnClause`` if a - name is given, else just returns self. - """ - - if name is not None: - co = _ColumnClause(name, selectable) - co.orig_set = self.orig_set - selectable.columns[name]= co - return co - else: - return self - -class ColumnCollection(util.OrderedProperties): - """An ordered dictionary that stores a list of ColumnElement - instances. - - Overrides the ``__eq__()`` method to produce SQL clauses between - sets of correlated columns. 
- """ - - def __init__(self, *cols): - super(ColumnCollection, self).__init__() - [self.add(c) for c in cols] - - def __str__(self): - return repr([str(c) for c in self]) - - def add(self, column): - """Add a column to this collection. - - The key attribute of the column will be used as the hash key - for this dictionary. - """ - - # Allow an aliased column to replace an unaliased column of the - # same name. - if self.has_key(column.name): - other = self[column.name] - if other.name == other.key: - del self[other.name] - self[column.key] = column - - def remove(self, column): - del self[column.key] - - def extend(self, iter): - for c in iter: - self.add(c) - - def __eq__(self, other): - l = [] - for c in other: - for local in self: - if c.shares_lineage(local): - l.append(c==local) - return and_(*l) - - def __contains__(self, other): - if not isinstance(other, basestring): - raise exceptions.ArgumentError("__contains__ requires a string argument") - return self.has_key(other) - - def contains_column(self, col): - # have to use a Set here, because it will compare the identity - # of the column, not just using "==" for comparison which will always return a - # "True" value (i.e. a BinaryClause...) - return col in util.Set(self) - -class ColumnSet(util.OrderedSet): - def contains_column(self, col): - return col in self - - def extend(self, cols): - for col in cols: - self.add(col) - - def __add__(self, other): - return list(self) + list(other) - - def __eq__(self, other): - l = [] - for c in other: - for local in self: - if c.shares_lineage(local): - l.append(c==local) - return and_(*l) - -class FromClause(Selectable): - """Represent an element that can be used within the ``FROM`` - clause of a ``SELECT`` statement. - """ - - __visit_name__ = 'fromclause' - - def __init__(self, name=None): - self.name = name - - def _get_from_objects(self, **modifiers): - # this could also be [self], at the moment it doesnt matter to the Select object - return [] - - def default_order_by(self): - return [self.oid_column] - - def count(self, whereclause=None, **params): - if len(self.primary_key): - col = list(self.primary_key)[0] - else: - col = list(self.columns)[0] - return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) - - def join(self, right, *args, **kwargs): - return Join(self, right, *args, **kwargs) - - def outerjoin(self, right, *args, **kwargs): - return Join(self, right, isouter=True, *args, **kwargs) - - def alias(self, name=None): - return Alias(self, name) - - def named_with_column(self): - """True if the name of this FromClause may be prepended to a - column in a generated SQL statement. - """ - - return False - - def _locate_oid_column(self): - """Subclasses should override this to return an appropriate OID column.""" - - return None - - def _get_oid_column(self): - if not hasattr(self, '_oid_column'): - self._oid_column = self._locate_oid_column() - return self._oid_column - - def _get_all_embedded_columns(self): - ret = [] - class FindCols(ClauseVisitor): - def visit_column(self, col): - ret.append(col) - FindCols().traverse(self) - return ret - - def is_derived_from(self, fromclause): - """return True if this FromClause is 'derived' from the given FromClause. 
- - An example would be an Alias of a Table is derived from that Table.""" - - return False - - def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False): - """Given a ``ColumnElement``, return the exported - ``ColumnElement`` object from this ``Selectable`` which - corresponds to that original ``Column`` via a common - anscestor column. - - column - the target ``ColumnElement`` to be matched - - raiseerr - if True, raise an error if the given ``ColumnElement`` - could not be matched. if False, non-matches will - return None. - - keys_ok - if the ``ColumnElement`` cannot be matched, attempt to - match based on the string "key" property of the column - alone. This makes the search much more liberal. - - require_embedded - only return corresponding columns for the given - ``ColumnElement``, if the given ``ColumnElement`` is - actually present within a sub-element of this - ``FromClause``. Normally the column will match if - it merely shares a common anscestor with one of - the exported columns of this ``FromClause``. - """ - - if self.c.contains_column(column): - return column - - if require_embedded and column not in util.Set(self._get_all_embedded_columns()): - if not raiseerr: - return None - else: - raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table)) - for c in column.orig_set: - try: - return self.original_columns[c] - except KeyError: - pass - else: - if keys_ok: - try: - return self.c[column.name] - except KeyError: - pass - if not raiseerr: - return None - else: - raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.name)) - - def _get_exported_attribute(self, name): - try: - return getattr(self, name) - except AttributeError: - self._export_columns() - return getattr(self, name) - - def _clone_from_clause(self): - # delete all the "generated" collections of columns for a newly cloned FromClause, - # so that they will be re-derived from the item. - # this is because FromClause subclasses, when cloned, need to reestablish new "proxied" - # columns that are linked to the new item - for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'): - if hasattr(self, attr): - delattr(self, attr) - - columns = property(lambda s:s._get_exported_attribute('_columns')) - c = property(lambda s:s._get_exported_attribute('_columns')) - primary_key = property(lambda s:s._get_exported_attribute('_primary_key')) - foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys')) - original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'), doc=\ - """A dictionary mapping an original Table-bound - column to a proxied column in this FromClause. - """) - oid_column = property(_get_oid_column) - - def _export_columns(self, columns=None): - """Initialize column collections. - - The collections include the primary key, foreign keys, list of - all columns, as well as the *_orig_cols* collection which is a - dictionary used to match Table-bound columns to proxied - columns in this ``FromClause``. The columns in each - collection are *proxied* from the columns returned by the - _exportable_columns method, where a *proxied* column maintains - most or all of the properties of its original column, except - its parent ``Selectable`` is this ``FromClause``. 
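A sketch of the FromClause join(), outerjoin() and count() helpers shown above; the ``addresses`` table is invented for illustration, and its ForeignKey lets join() infer the ON clause::

    from sqlalchemy import Table, Column, Integer, ForeignKey

    addresses = Table('addresses', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('user_id', Integer, ForeignKey('users.id')))

    j = users.join(addresses)          # ON users.id = addresses.user_id, inferred
    oj = users.outerjoin(addresses, users.c.id == addresses.c.user_id)
    total = users.count()              # SELECT count(users.id) AS tbl_row_count ...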
- """ - - if hasattr(self, '_columns') and columns is None: - # TODO: put a mutex here ? this is a key place for threading probs - return - self._columns = ColumnCollection() - self._primary_key = ColumnSet() - self._foreign_keys = util.Set() - self._orig_cols = {} - - if columns is None: - columns = self._flatten_exportable_columns() - for co in columns: - cp = self._proxy_column(co) - for ci in cp.orig_set: - cx = self._orig_cols.get(ci) - # TODO: the '=' thing here relates to the order of columns as they are placed in the - # "columns" collection of a CompositeSelect, illustrated in test/sql/selectable.SelectableTest.testunion - # make this relationship less brittle - if cx is None or cp._distance <= cx._distance: - self._orig_cols[ci] = cp - if self.oid_column is not None: - for ci in self.oid_column.orig_set: - self._orig_cols[ci] = self.oid_column - - def _flatten_exportable_columns(self): - """return the list of ColumnElements represented within this FromClause's _exportable_columns""" - export = self._exportable_columns() - for column in export: - # TODO: is this conditional needed ? - if isinstance(column, Selectable): - s = column - else: - continue - for co in s.columns: - yield co - - def _exportable_columns(self): - return [] - - def _proxy_column(self, column): - return column._make_proxy(self) - -class _BindParamClause(ClauseElement, _CompareMixin): - """Represent a bind parameter. - - Public constructor is the ``bindparam()`` function. - """ - - __visit_name__ = 'bindparam' - - def __init__(self, key, value, shortname=None, type_=None, unique=False, isoutparam=False): - """Construct a _BindParamClause. - - key - the key for this bind param. Will be used in the generated - SQL statement for dialects that use named parameters. This - value may be modified when part of a compilation operation, - if other ``_BindParamClause`` objects exist with the same - key, or if its length is too long and truncation is - required. - - value - Initial value for this bind param. This value may be - overridden by the dictionary of parameters sent to statement - compilation/execution. - - shortname - Defaults to the key, a *short name* that will also identify - this bind parameter, similar to an alias. the bind - parameter keys sent to a statement compilation or compiled - execution may match either the key or the shortname of the - corresponding ``_BindParamClause`` objects. - - type\_ - A ``TypeEngine`` object that will be used to pre-process the - value corresponding to this ``_BindParamClause`` at - execution time. - - unique - if True, the key name of this BindParamClause will be - modified if another ``_BindParamClause`` of the same - name already has been located within the containing - ``ClauseElement``. - - isoutparam - if True, the parameter should be treated like a stored procedure "OUT" - parameter. 
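The ``unique`` flag documented above keeps identically named parameters from colliding; a small sketch::

    a = bindparam('x', 7, unique=True)
    b = bindparam('x', 8, unique=True)
    expr = (users.c.id == a) | (users.c.id == b)
    # when compiled, the second parameter's key is rewritten to stay unique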
- """ - - self.key = key or "{ANON %d param}" % id(self) - self.value = value - self.shortname = shortname or key - self.unique = unique - self.isoutparam = isoutparam - type_ = sqltypes.to_instance(type_) - if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map: - self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)]) - else: - self.type = type_ - - # TODO: move to types module, obviously - type_map = { - str : sqltypes.String, - unicode : sqltypes.Unicode, - int : sqltypes.Integer, - float : sqltypes.Numeric - } - - def _get_from_objects(self, **modifiers): - return [] - - def typeprocess(self, value, dialect): - return self.type.dialect_impl(dialect).convert_bind_param(value, dialect) - - def compare(self, other): - """Compare this ``_BindParamClause`` to the given clause. - - Since ``compare()`` is meant to compare statement syntax, this - method returns True if the two ``_BindParamClauses`` have just - the same type. - """ - - return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ - - def __repr__(self): - return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type)) - -class _TypeClause(ClauseElement): - """Handle a type keyword in a SQL statement. - - Used by the ``Case`` statement. - """ - - __visit_name__ = 'typeclause' - - def __init__(self, type): - self.type = type - - def _get_from_objects(self, **modifiers): - return [] - -class _TextClause(ClauseElement): - """Represent a literal SQL text fragment. - - Public constructor is the ``text()`` function. - """ - - __visit_name__ = 'textclause' - - def __init__(self, text = "", bind=None, bindparams=None, typemap=None): - self._bind = bind - self.bindparams = {} - self.typemap = typemap - if typemap is not None: - for key in typemap.keys(): - typemap[key] = sqltypes.to_instance(typemap[key]) - - def repl(m): - self.bindparams[m.group(1)] = bindparam(m.group(1)) - return ":%s" % m.group(1) - - # scan the string and search for bind parameter names, add them - # to the list of bindparams - self.text = BIND_PARAMS.sub(repl, text) - if bindparams is not None: - for b in bindparams: - self.bindparams[b.key] = b - - def _get_type(self): - if self.typemap is not None and len(self.typemap) == 1: - return list(self.typemap)[0] - else: - return None - type = property(_get_type) - - columns = property(lambda s:[]) - - def _copy_internals(self): - self.bindparams = [b._clone() for b in self.bindparams] - - def get_children(self, **kwargs): - return self.bindparams.values() - - def _get_from_objects(self, **modifiers): - return [] - - def supports_execution(self): - return True - -class _Null(ColumnElement): - """Represent the NULL keyword in a SQL statement. - - Public constructor is the ``null()`` function. - """ - - def __init__(self): - self.type = sqltypes.NULLTYPE - - def _get_from_objects(self, **modifiers): - return [] - -class ClauseList(ClauseElement): - """Describe a list of clauses, separated by an operator. - - By default, is comma-separated, such as a column listing. 
- """ - __visit_name__ = 'clauselist' - - def __init__(self, *clauses, **kwargs): - self.clauses = [] - self.operator = kwargs.pop('operator', ColumnOperators.comma_op) - self.group = kwargs.pop('group', True) - self.group_contents = kwargs.pop('group_contents', True) - for c in clauses: - if c is None: - continue - self.append(c) - - def __iter__(self): - return iter(self.clauses) - def __len__(self): - return len(self.clauses) - - def append(self, clause): - # TODO: not sure if i like the 'group_contents' flag. need to define the difference between - # a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ? - if self.group_contents: - self.clauses.append(_literal_as_text(clause).self_group(against=self.operator)) - else: - self.clauses.append(_literal_as_text(clause)) - - def _copy_internals(self): - self.clauses = [clause._clone() for clause in self.clauses] - - def get_children(self, **kwargs): - return self.clauses - - def _get_from_objects(self, **modifiers): - f = [] - for c in self.clauses: - f += c._get_from_objects(**modifiers) - return f - - def self_group(self, against=None): - if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]): - return _Grouping(self) - else: - return self - - def compare(self, other): - """Compare this ``ClauseList`` to the given ``ClauseList``, - including a comparison of all the clause items. - """ - - if not isinstance(other, ClauseList) and len(self.clauses) == 1: - return self.clauses[0].compare(other) - elif isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses): - for i in range(0, len(self.clauses)): - if not self.clauses[i].compare(other.clauses[i]): - return False - else: - return self.operator == other.operator - else: - return False - -class _CalculatedClause(ColumnElement): - """Describe a calculated SQL expression that has a type, like ``CASE``. - - Extends ``ColumnElement`` to provide column-level comparison - operators. - """ - __visit_name__ = 'calculatedclause' - - def __init__(self, name, *clauses, **kwargs): - self.name = name - self.type = sqltypes.to_instance(kwargs.get('type_', None)) - self._bind = kwargs.get('bind', None) - self.group = kwargs.pop('group', True) - clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) - if self.group: - self.clause_expr = clauses.self_group() - else: - self.clause_expr = clauses - - key = property(lambda self:self.name or "_calc_") - - def _copy_internals(self): - self.clause_expr = self.clause_expr._clone() - - def clauses(self): - if isinstance(self.clause_expr, _Grouping): - return self.clause_expr.elem - else: - return self.clause_expr - clauses = property(clauses) - - def get_children(self, **kwargs): - return self.clause_expr, - - def _get_from_objects(self, **modifiers): - return self.clauses._get_from_objects(**modifiers) - - def _bind_param(self, obj): - return _BindParamClause(self.name, obj, type_=self.type, unique=True) - - def select(self): - return select([self]) - - def scalar(self): - return select([self]).execute().scalar() - - def execute(self): - return select([self]).execute() - - def _compare_type(self, obj): - return self.type - -class _Function(_CalculatedClause, FromClause): - """Describe a SQL function. - - Extends ``_CalculatedClause``, turn the *clauselist* into function - arguments, also adds a `packagenames` argument. 
- """ - - def __init__(self, name, *clauses, **kwargs): - self.packagenames = kwargs.get('packagenames', None) or [] - kwargs['operator'] = ColumnOperators.comma_op - _CalculatedClause.__init__(self, name, **kwargs) - for c in clauses: - self.append(c) - - key = property(lambda self:self.name) - - def _copy_internals(self): - _CalculatedClause._copy_internals(self) - self._clone_from_clause() - - def get_children(self, **kwargs): - return _CalculatedClause.get_children(self, **kwargs) - - def append(self, clause): - self.clauses.append(_literal_as_binds(clause, self.name)) - -class _Cast(ColumnElement): - - def __init__(self, clause, totype, **kwargs): - if not hasattr(clause, 'label'): - clause = literal(clause) - self.type = sqltypes.to_instance(totype) - self.clause = clause - self.typeclause = _TypeClause(self.type) - self._distance = 0 - - def _copy_internals(self): - self.clause = self.clause._clone() - self.typeclause = self.typeclause._clone() - - def get_children(self, **kwargs): - return self.clause, self.typeclause - - def _get_from_objects(self, **modifiers): - return self.clause._get_from_objects(**modifiers) - - def _make_proxy(self, selectable, name=None): - if name is not None: - co = _ColumnClause(name, selectable, type_=self.type) - co._distance = self._distance + 1 - co.orig_set = self.orig_set - selectable.columns[name]= co - return co - else: - return self - - -class _UnaryExpression(ColumnElement): - def __init__(self, element, operator=None, modifier=None, type_=None, negate=None): - self.operator = operator - self.modifier = modifier - - self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier) - self.type = sqltypes.to_instance(type_) - self.negate = negate - - def _get_from_objects(self, **modifiers): - return self.element._get_from_objects(**modifiers) - - def _copy_internals(self): - self.element = self.element._clone() - - def get_children(self, **kwargs): - return self.element, - - def compare(self, other): - """Compare this ``_UnaryExpression`` against the given ``ClauseElement``.""" - - return ( - isinstance(other, _UnaryExpression) and self.operator == other.operator and - self.modifier == other.modifier and - self.element.compare(other.element) - ) - - def _negate(self): - if self.negate is not None: - return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type) - else: - return super(_UnaryExpression, self)._negate() - - def self_group(self, against): - if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]): - return _Grouping(self) - else: - return self - - -class _BinaryExpression(ColumnElement): - """Represent an expression that is ``LEFT RIGHT``.""" - - def __init__(self, left, right, operator, type_=None, negate=None): - self.left = _literal_as_text(left).self_group(against=operator) - self.right = _literal_as_text(right).self_group(against=operator) - self.operator = operator - self.type = sqltypes.to_instance(type_) - self.negate = negate - - def _get_from_objects(self, **modifiers): - return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - - def _copy_internals(self): - self.left = self.left._clone() - self.right = self.right._clone() - - def get_children(self, **kwargs): - return self.left, self.right - - def compare(self, other): - """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" - - return ( - isinstance(other, 
_BinaryExpression) and self.operator == other.operator and - ( - self.left.compare(other.left) and self.right.compare(other.right) - or ( - self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and - self.left.compare(other.right) and self.right.compare(other.left) - ) - ) - ) - - def self_group(self, against=None): - # use small/large defaults for comparison so that unknown operators are always parenthesized - if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])): - return _Grouping(self) - else: - return self - - def _negate(self): - if self.negate is not None: - return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type) - else: - return super(_BinaryExpression, self)._negate() - -class _Exists(_UnaryExpression): - __visit_name__ = _UnaryExpression.__visit_name__ - - def __init__(self, *args, **kwargs): - kwargs['correlate'] = True - s = select(*args, **kwargs).self_group() - _UnaryExpression.__init__(self, s, operator=Operators.exists) - - def _hide_froms(self, **modifiers): - return self._get_from_objects(**modifiers) - -class Join(FromClause): - """represent a ``JOIN`` construct between two ``FromClause`` - elements. - - the public constructor function for ``Join`` is the module-level - ``join()`` function, as well as the ``join()`` method available - off all ``FromClause`` subclasses. - - """ - def __init__(self, left, right, onclause=None, isouter = False): - self.left = _selectable(left) - self.right = _selectable(right).self_group() - if onclause is None: - self.onclause = self._match_primaries(self.left, self.right) - else: - self.onclause = onclause - self.isouter = isouter - self.__folded_equivalents = None - self._init_primary_key() - - name = property(lambda s: "Join object on " + s.left.name + " " + s.right.name) - encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace')) - - def _init_primary_key(self): - pkcol = util.Set([c for c in self._flatten_exportable_columns() if c.primary_key]) - - equivs = {} - def add_equiv(a, b): - for x, y in ((a, b), (b, a)): - if x in equivs: - equivs[x].add(y) - else: - equivs[x] = util.Set([y]) - - class BinaryVisitor(ClauseVisitor): - def visit_binary(self, binary): - if binary.operator == operator.eq: - add_equiv(binary.left, binary.right) - BinaryVisitor().traverse(self.onclause) - - for col in pkcol: - for fk in col.foreign_keys: - if fk.column in pkcol: - add_equiv(col, fk.column) - - omit = util.Set() - for col in pkcol: - p = col - for c in equivs.get(col, util.Set()): - if p.references(c) or (c.primary_key and not p.primary_key): - omit.add(p) - p = c - - self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit]) - - primary_key = property(lambda s:s.__primary_key) - - def self_group(self, against=None): - return _Grouping(self) - - def _locate_oid_column(self): - return self.left.oid_column - - def _exportable_columns(self): - return [c for c in self.left.columns] + [c for c in self.right.columns] - - def _proxy_column(self, column): - self._columns[column._label] = column - for f in column.foreign_keys: - self._foreign_keys.add(f) - return column - - def _copy_internals(self): - self._clone_from_clause() - self.left = self.left._clone() - self.right = self.right._clone() - self.onclause = self.onclause._clone() - self.__folded_equivalents = None - self._init_primary_key() - - def get_children(self, **kwargs): - return 
self.left, self.right, self.onclause - - def _match_primaries(self, primary, secondary): - crit = [] - constraints = util.Set() - for fk in secondary.foreign_keys: - if fk.references(primary): - crit.append(primary.corresponding_column(fk.column) == fk.parent) - constraints.add(fk.constraint) - self.foreignkey = fk.parent - if primary is not secondary: - for fk in primary.foreign_keys: - if fk.references(secondary): - crit.append(secondary.corresponding_column(fk.column) == fk.parent) - constraints.add(fk.constraint) - self.foreignkey = fk.parent - if len(crit) == 0: - raise exceptions.ArgumentError("Can't find any foreign key relationships " - "between '%s' and '%s'" % (primary.name, secondary.name)) - elif len(constraints) > 1: - raise exceptions.ArgumentError("Can't determine join between '%s' and '%s'; " - "tables have more than one foreign key " - "constraint relationship between them. " - "Please specify the 'onclause' of this " - "join explicitly." % (primary.name, secondary.name)) - elif len(crit) == 1: - return (crit[0]) - else: - return and_(*crit) - - def _get_folded_equivalents(self, equivs=None): - if self.__folded_equivalents is not None: - return self.__folded_equivalents - if equivs is None: - equivs = util.Set() - class LocateEquivs(NoColumnVisitor): - def visit_binary(self, binary): - if binary.operator == operator.eq and binary.left.name == binary.right.name: - equivs.add(binary.right) - equivs.add(binary.left) - LocateEquivs().traverse(self.onclause) - collist = [] - if isinstance(self.left, Join): - left = self.left._get_folded_equivalents(equivs) - else: - left = list(self.left.columns) - if isinstance(self.right, Join): - right = self.right._get_folded_equivalents(equivs) - else: - right = list(self.right.columns) - used = util.Set() - for c in left + right: - if c in equivs: - if c.name not in used: - collist.append(c) - used.add(c.name) - else: - collist.append(c) - self.__folded_equivalents = collist - return self.__folded_equivalents - - folded_equivalents = property(_get_folded_equivalents, doc="Returns the column list of this Join with all equivalently-named, " - "equated columns folded into one column, where 'equated' means they are " - "equated to each other in the ON clause of this join.") - - def select(self, whereclause = None, fold_equivalents=False, **kwargs): - """Create a ``Select`` from this ``Join``. - - whereclause - the WHERE criterion that will be sent to the ``select()`` function - - fold_equivalents - based on the join criterion of this ``Join``, do not include repeat - column names in the column list of the resulting select, for columns that - are calculated to be "equivalent" based on the join criterion of this - ``Join``. this will recursively apply to any joins directly nested by - this one as well. - - \**kwargs - all other kwargs are sent to the underlying ``select()`` function. - See the ``select()`` module level function for details. - - """ - if fold_equivalents: - collist = self.folded_equivalents - else: - collist = [self.left, self.right] - - return select(collist, whereclause, from_obj=[self], **kwargs) - - bind = property(lambda s:s.left.bind or s.right.bind) - - def alias(self, name=None): - """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it. - - The ``Select`` is not correlating. 
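    A rough usage sketch (``users`` and ``addresses`` are hypothetical
    tables); the resulting alias wraps a labeled, non-correlating SELECT of
    the join, as implemented just below:

        j = users.join(addresses, users.c.id == addresses.c.user_id)
        a = j.alias('user_addresses')
        # roughly: select([users, addresses], from_obj=[j],
        #                 use_labels=True, correlate=False).alias('user_addresses')
        stmt = select([a])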
- """ - - return self.select(use_labels=True, correlate=False).alias(name) - - def _hide_froms(self, **modifiers): - return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - - def _get_from_objects(self, **modifiers): - return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - -class Alias(FromClause): - """represent an alias, as typically applied to any - table or sub-select within a SQL statement using the - ``AS`` keyword (or without the keyword on certain databases - such as Oracle). - - this object is constructed from the ``alias()`` module level function - as well as the ``alias()`` method available on all ``FromClause`` - subclasses. - - """ - def __init__(self, selectable, alias=None): - baseselectable = selectable - while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable - self.original = baseselectable - self.selectable = selectable - if alias is None: - if self.original.named_with_column(): - alias = getattr(self.original, 'name', None) - alias = '{ANON %d %s}' % (id(self), alias or 'anon') - self.name = alias - self.encodedname = alias.encode('ascii', 'backslashreplace') - self.case_sensitive = getattr(baseselectable, "case_sensitive", True) - - def is_derived_from(self, fromclause): - x = self.selectable - while True: - if x is fromclause: - return True - if isinstance(x, Alias): - x = x.selectable - else: - break - return False - - def supports_execution(self): - return self.original.supports_execution() - - def _locate_oid_column(self): - if self.selectable.oid_column is not None: - return self.selectable.oid_column._make_proxy(self) - else: - return None - - def named_with_column(self): - return True - - def _exportable_columns(self): - #return self.selectable._exportable_columns() - return self.selectable.columns - - def _copy_internals(self): - self._clone_from_clause() - self.selectable = self.selectable._clone() - baseselectable = self.selectable - while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable - self.original = baseselectable - - def get_children(self, **kwargs): - for c in self.c: - yield c - yield self.selectable - - def _get_from_objects(self): - return [self] - - bind = property(lambda s: s.selectable.bind) - -class _ColumnElementAdapter(ColumnElement): - """adapts a ClauseElement which may or may not be a - ColumnElement subclass itself into an object which - acts like a ColumnElement. - """ - - def __init__(self, elem): - self.elem = elem - self.type = getattr(elem, 'type', None) - self.orig_set = getattr(elem, 'orig_set', util.Set()) - - key = property(lambda s: s.elem.key) - _label = property(lambda s: s.elem._label) - columns = c = property(lambda s:s.elem.columns) - - def _copy_internals(self): - self.elem = self.elem._clone() - - def get_children(self, **kwargs): - return self.elem, - - def _hide_froms(self, **modifiers): - return self.elem._hide_froms(**modifiers) - - def _get_from_objects(self, **modifiers): - return self.elem._get_from_objects(**modifiers) - - def __getattr__(self, attr): - return getattr(self.elem, attr) - -class _Grouping(_ColumnElementAdapter): - pass - -class _Label(ColumnElement): - """represent a label, as typically applied to any column-level element - using the ``AS`` sql keyword. - - this object is constructed from the ``label()`` module level function - as well as the ``label()`` method available on all ``ColumnElement`` - subclasses. 
- - """ - - def __init__(self, name, obj, type_=None): - while isinstance(obj, _Label): - obj = obj.obj - self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) - - self.obj = obj.self_group(against=Operators.as_) - self.case_sensitive = getattr(obj, "case_sensitive", True) - self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) - - key = property(lambda s: s.name) - _label = property(lambda s: s.name) - orig_set = property(lambda s:s.obj.orig_set) - - def expression_element(self): - return self.obj - - def _copy_internals(self): - self.obj = self.obj._clone() - - def get_children(self, **kwargs): - return self.obj, - - def _get_from_objects(self, **modifiers): - return self.obj._get_from_objects(**modifiers) - - def _hide_froms(self, **modifiers): - return self.obj._hide_froms(**modifiers) - - def _make_proxy(self, selectable, name = None): - if isinstance(self.obj, Selectable): - return self.obj._make_proxy(selectable, name=self.name) - else: - return column(self.name)._make_proxy(selectable=selectable) - -class _ColumnClause(ColumnElement): - """Represents a generic column expression from any textual string. - This includes columns associated with tables, aliases and select - statements, but also any arbitrary text. May or may not be bound - to an underlying ``Selectable``. ``_ColumnClause`` is usually - created publically via the ``column()`` function or the - ``column_literal()`` function. - - text - the text of the element. - - selectable - parent selectable. - - type - ``TypeEngine`` object which can associate this ``_ColumnClause`` - with a type. - - case_sensitive - defines whether identifier quoting rules will be applied to the - generated text of this ``_ColumnClause`` so that it is identified in - a case-sensitive manner. - - is_literal - if True, the ``_ColumnClause`` is assumed to be an exact expression - that will be delivered to the output with no quoting rules applied - regardless of case sensitive settings. the ``column_literal()`` function is - usually used to create such a ``_ColumnClause``. - - """ - - def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False): - self.key = self.name = text - self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name - self.table = selectable - self.type = sqltypes.to_instance(type_) - self._is_oid = _is_oid - self._distance = 0 - self.__label = None - self.case_sensitive = case_sensitive - self.is_literal = is_literal - - def _clone(self): - # ColumnClause is immutable - return self - - def _get_label(self): - """Generate a 'label' for this column. - - The label is a product of the parent table name and column - name, and is treated as a unique identifier of this ``Column`` - across all ``Tables`` and derived selectables for a particular - metadata collection. 
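    For example (hypothetical ``users`` table), the ``name`` column of a
    table named ``users`` receives the label ``users_name``; with
    ``use_labels=True`` these labels appear in the rendered SQL:

        stmt = select([users], use_labels=True)
        # SELECT users.id AS users_id, users.name AS users_name FROM users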
- """ - - # for a "literal" column, we've no idea what the text is - # therefore no 'label' can be automatically generated - if self.is_literal: - return None - if self.__label is None: - if self.table is not None and self.table.named_with_column(): - self.__label = self.table.name + "_" + self.name - counter = 1 - while self.table.c.has_key(self.__label): - self.__label = self.__label + "_%d" % counter - counter += 1 - else: - self.__label = self.name - return self.__label - - is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name) - - _label = property(_get_label) - - def label(self, name): - # if going off the "__label" property and its None, we have - # no label; return self - if name is None: - return self - else: - return super(_ColumnClause, self).label(name) - - def _get_from_objects(self, **modifiers): - if self.table is not None: - return [self.table] - else: - return [] - - def _bind_param(self, obj): - return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True) - - def _make_proxy(self, selectable, name = None): - # propigate the "is_literal" flag only if we are keeping our name, - # otherwise its considered to be a label - is_literal = self.is_literal and (name is None or name == self.name) - c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) - c.orig_set = self.orig_set - c._distance = self._distance + 1 - if not self._is_oid: - selectable.columns[c.name] = c - return c - - def _compare_type(self, obj): - return self.type - -class TableClause(FromClause): - """represents a "table" construct. - - Note that this represents tables only as another - syntactical construct within SQL expressions; it - does not provide schema-level functionality. 
- - """ - - def __init__(self, name, *columns): - super(TableClause, self).__init__(name) - self.name = self.fullname = name - self.encodedname = self.name.encode('ascii', 'backslashreplace') - self._oid_column = _ColumnClause('oid', self, _is_oid=True) - self._export_columns(columns) - - def _clone(self): - # TableClause is immutable - return self - - def named_with_column(self): - return True - - def append_column(self, c): - self._columns[c.name] = c - c.table = self - - def _locate_oid_column(self): - return self._oid_column - - def _proxy_column(self, c): - self.append_column(c) - return c - - def _orig_columns(self): - try: - return self._orig_cols - except AttributeError: - self._orig_cols= {} - for c in self.columns: - for ci in c.orig_set: - self._orig_cols[ci] = c - return self._orig_cols - - original_columns = property(_orig_columns) - - def get_children(self, column_collections=True, **kwargs): - if column_collections: - return [c for c in self.c] - else: - return [] - - def _exportable_columns(self): - raise NotImplementedError() - - def count(self, whereclause=None, **params): - if len(self.primary_key): - col = list(self.primary_key)[0] - else: - col = list(self.columns)[0] - return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) - - def join(self, right, *args, **kwargs): - return Join(self, right, *args, **kwargs) - - def outerjoin(self, right, *args, **kwargs): - return Join(self, right, isouter = True, *args, **kwargs) - - def alias(self, name=None): - return Alias(self, name) - - def select(self, whereclause = None, **params): - return select([self], whereclause, **params) - - def insert(self, values = None): - return insert(self, values=values) - - def update(self, whereclause = None, values = None): - return update(self, whereclause, values) - - def delete(self, whereclause = None): - return delete(self, whereclause) - - def _get_from_objects(self, **modifiers): - return [self] - - -class _SelectBaseMixin(object): - """Base class for ``Select`` and ``CompoundSelects``.""" - - def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None): - self.use_labels = use_labels - self.for_update = for_update - self._limit = limit - self._offset = offset - self._bind = bind - - self.append_order_by(*util.to_list(order_by, [])) - self.append_group_by(*util.to_list(group_by, [])) - - def as_scalar(self): - return _ScalarSelect(self) - - def label(self, name): - return self.as_scalar().label(name) - - def supports_execution(self): - return True - - def _generate(self): - s = self._clone() - s._clone_from_clause() - return s - - def limit(self, limit): - s = self._generate() - s._limit = limit - return s - - def offset(self, offset): - s = self._generate() - s._offset = offset - return s - - def order_by(self, *clauses): - s = self._generate() - s.append_order_by(*clauses) - return s - - def group_by(self, *clauses): - s = self._generate() - s.append_group_by(*clauses) - return s - - def append_order_by(self, *clauses): - if clauses == [None]: - self._order_by_clause = ClauseList() - else: - if getattr(self, '_order_by_clause', None): - clauses = list(self._order_by_clause) + list(clauses) - self._order_by_clause = ClauseList(*clauses) - - def append_group_by(self, *clauses): - if clauses == [None]: - self._group_by_clause = ClauseList() - else: - if getattr(self, '_group_by_clause', None): - clauses = list(self._group_by_clause) + list(clauses) - self._group_by_clause = 
ClauseList(*clauses) - - def select(self, whereclauses = None, **params): - return select([self], whereclauses, **params) - - def _get_from_objects(self, is_where=False, **modifiers): - if is_where: - return [] - else: - return [self] - -class _ScalarSelect(_Grouping): - __visit_name__ = 'grouping' - - def __init__(self, elem): - super(_ScalarSelect, self).__init__(elem) - self.type = list(elem.inner_columns)[0].type - - columns = property(lambda self:[self]) - - def self_group(self, **kwargs): - return self - - def _make_proxy(self, selectable, name): - return list(self.inner_columns)[0]._make_proxy(selectable, name) - - def _get_from_objects(self, **modifiers): - return [] - -class CompoundSelect(_SelectBaseMixin, FromClause): - def __init__(self, keyword, *selects, **kwargs): - self._should_correlate = kwargs.pop('correlate', False) - self.keyword = keyword - self.selects = [] - - # some DBs do not like ORDER BY in the inner queries of a UNION, etc. - for n, s in enumerate(selects): - if len(s._order_by_clause): - s = s.order_by(None) - # unions group from left to right, so don't group first select - if n: - self.selects.append(s.self_group(self)) - else: - self.selects.append(s) - - self._col_map = {} - - _SelectBaseMixin.__init__(self, **kwargs) - - name = property(lambda s:s.keyword + " statement") - - def self_group(self, against=None): - return _Grouping(self) - - def _locate_oid_column(self): - return self.selects[0].oid_column - - def _exportable_columns(self): - for s in self.selects: - for c in s.c: - yield c - - def _proxy_column(self, column): - if self.use_labels: - col = column._make_proxy(self, name=column._label) - else: - col = column._make_proxy(self) - try: - colset = self._col_map[col.name] - except KeyError: - colset = util.Set() - self._col_map[col.name] = colset - [colset.add(c) for c in col.orig_set] - col.orig_set = colset - return col - - def _copy_internals(self): - self._clone_from_clause() - self._col_map = {} - self.selects = [s._clone() for s in self.selects] - for attr in ('_order_by_clause', '_group_by_clause'): - if getattr(self, attr) is not None: - setattr(self, attr, getattr(self, attr)._clone()) - - def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.c) or []) + \ - [self._order_by_clause, self._group_by_clause] + list(self.selects) - - def _find_engine(self): - for s in self.selects: - e = s._find_engine() - if e: - return e - else: - return None - -class Select(_SelectBaseMixin, FromClause): - """Represent a ``SELECT`` statement, with appendable clauses, as - well as the ability to execute itself and return a result set. - - """ - - def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs): - """construct a Select object. - - The public constructor for Select is the [sqlalchemy.sql#select()] function; - see that function for argument descriptions. 
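    A short usage sketch of the public constructor (table and column names
    are hypothetical):

        from sqlalchemy.sql import table, column, select
        users = table('users', column('id'), column('name'))
        stmt = select([users.c.id, users.c.name],
                      users.c.name == 'jack',
                      order_by=[users.c.id],
                      limit=10)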
- """ - - self._should_correlate = correlate - self._distinct = distinct - - self._raw_columns = [] - self.__correlate = util.Set() - self._froms = util.OrderedSet() - self._whereclause = None - self._having = None - self._prefixes = [] - - if columns is not None: - for c in columns: - self.append_column(c) - - if from_obj is not None: - for f in from_obj: - self.append_from(f) - - if whereclause is not None: - self.append_whereclause(whereclause) - - if having is not None: - self.append_having(having) - - _SelectBaseMixin.__init__(self, **kwargs) - - def _get_display_froms(self, correlation_state=None): - """return the full list of 'from' clauses to be displayed. - - takes into account an optional 'correlation_state' - dictionary which contains information about this Select's - correlation to an enclosing select, which may cause some 'from' - clauses to not display in this Select's FROM clause. - this dictionary is generated during compile time by the - _calculate_correlations() method. - - """ - froms = util.OrderedSet() - hide_froms = util.Set() - - for col in self._raw_columns: - for f in col._hide_froms(): - hide_froms.add(f) - for f in col._get_from_objects(): - froms.add(f) - - if self._whereclause is not None: - for f in self._whereclause._get_from_objects(is_where=True): - froms.add(f) - - for elem in self._froms: - froms.add(elem) - for f in elem._get_from_objects(): - froms.add(f) - - for elem in froms: - for f in elem._hide_froms(): - hide_froms.add(f) - - froms = froms.difference(hide_froms) - - if len(froms) > 1: - corr = self.__correlate - if correlation_state is not None: - corr = correlation_state[self].get('correlate', util.Set()).union(corr) - f = froms.difference(corr) - if len(f) == 0: - raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) - return f - else: - return froms - - froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") - - def locate_all_froms(self): - froms = util.Set() - for col in self._raw_columns: - for f in col._get_from_objects(): - froms.add(f) - - if self._whereclause is not None: - for f in self._whereclause._get_from_objects(is_where=True): - froms.add(f) - - for elem in self._froms: - froms.add(elem) - for f in elem._get_from_objects(): - froms.add(f) - return froms - - def _calculate_correlations(self, correlation_state): - """generate a 'correlation_state' dictionary used by the _get_display_froms() method. - - The dictionary is passed in initially empty, or already - containing the state information added by an enclosing - Select construct. The method will traverse through all - embedded Select statements and add information about their - position and "from" objects to the dictionary. Those Select - statements will later consult the 'correlation_state' dictionary - when their list of 'FROM' clauses are generated using their - _get_display_froms() method. 
- """ - - if self not in correlation_state: - correlation_state[self] = {} - - display_froms = self._get_display_froms(correlation_state) - - class CorrelatedVisitor(NoColumnVisitor): - def __init__(self, is_where=False, is_column=False, is_from=False): - self.is_where = is_where - self.is_column = is_column - self.is_from = is_from - - def visit_compound_select(self, cs): - self.visit_select(cs) - - def visit_select(s, select): - if select not in correlation_state: - correlation_state[select] = {} - - if select is self: - return - - select_state = correlation_state[select] - if s.is_from: - select_state['is_selected_from'] = True - if s.is_where: - select_state['is_where'] = True - select_state['is_subquery'] = True - - if select._should_correlate: - corr = select_state.setdefault('correlate', util.Set()) - # not crazy about this part. need to be clearer on what elements in the - # subquery correspond to elements in the enclosing query. - for f in display_froms: - corr.add(f) - for f2 in f._get_from_objects(): - corr.add(f2) - - col_vis = CorrelatedVisitor(is_column=True) - where_vis = CorrelatedVisitor(is_where=True) - from_vis = CorrelatedVisitor(is_from=True) - - for col in self._raw_columns: - col_vis.traverse(col) - for f in col._get_from_objects(): - if f is not self: - from_vis.traverse(f) - - for col in list(self._order_by_clause) + list(self._group_by_clause): - col_vis.traverse(col) - - if self._whereclause is not None: - where_vis.traverse(self._whereclause) - for f in self._whereclause._get_from_objects(is_where=True): - if f is not self: - from_vis.traverse(f) - - for elem in self._froms: - from_vis.traverse(elem) - - def _get_inner_columns(self): - for c in self._raw_columns: - if isinstance(c, Selectable): - for co in c.columns: - yield co - else: - yield c - - inner_columns = property(_get_inner_columns) - - def _copy_internals(self): - self._clone_from_clause() - self._raw_columns = [c._clone() for c in self._raw_columns] - self._recorrelate_froms([(f, f._clone()) for f in self._froms]) - for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): - if getattr(self, attr) is not None: - setattr(self, attr, getattr(self, attr)._clone()) - - def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.columns) or []) + \ - list(self._froms) + \ - [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] - - def _recorrelate_froms(self, froms): - newcorrelate = util.Set() - newfroms = util.Set() - oldfroms = util.Set(self._froms) - for old, new in froms: - if old in self.__correlate: - newcorrelate.add(new) - self.__correlate.remove(old) - if old in oldfroms: - newfroms.add(new) - oldfroms.remove(old) - self.__correlate = self.__correlate.union(newcorrelate) - self._froms = [f for f in oldfroms.union(newfroms)] - - def column(self, column): - s = self._generate() - s.append_column(column) - return s - - def where(self, whereclause): - s = self._generate() - s.append_whereclause(whereclause) - return s - - def having(self, having): - s = self._generate() - s.append_having(having) - return s - - def distinct(self): - s = self._generate() - s.distinct = True - return s - - def prefix_with(self, clause): - s = self._generate() - s.append_prefix(clause) - return s - - def select_from(self, fromclause): - s = self._generate() - s.append_from(fromclause) - return s - - def __dont_correlate(self): - s = self._generate() - s._should_correlate = False - return s - - def 
correlate(self, fromclause): - s = self._generate() - s._should_correlate=False - if fromclause is None: - s.__correlate = util.Set() - else: - s.append_correlation(fromclause) - return s - - def append_correlation(self, fromclause): - self.__correlate.add(fromclause) - - def append_column(self, column): - column = _literal_as_column(column) - - if isinstance(column, _ScalarSelect): - column = column.self_group(against=ColumnOperators.comma_op) - - self._raw_columns.append(column) - - def append_prefix(self, clause): - clause = _literal_as_text(clause) - self._prefixes.append(clause) - - def append_whereclause(self, whereclause): - if self._whereclause is not None: - self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) - else: - self._whereclause = _literal_as_text(whereclause) - - def append_having(self, having): - if self._having is not None: - self._having = and_(self._having, _literal_as_text(having)) - else: - self._having = _literal_as_text(having) - - def append_from(self, fromclause): - if _is_literal(fromclause): - fromclause = FromClause(fromclause) - self._froms.add(fromclause) - - def _exportable_columns(self): - return [c for c in self._raw_columns if isinstance(c, Selectable)] - - def _proxy_column(self, column): - if self.use_labels: - return column._make_proxy(self, name=column._label) - else: - return column._make_proxy(self) - - def self_group(self, against=None): - if isinstance(against, CompoundSelect): - return self - return _Grouping(self) - - def _locate_oid_column(self): - for f in self.locate_all_froms(): - if f is self: - # we might be in our own _froms list if a column with us as the parent is attached, - # which includes textual columns. - continue - oid = f.oid_column - if oid is not None: - return oid - else: - return None - - def union(self, other, **kwargs): - return union(self, other, **kwargs) - - def union_all(self, other, **kwargs): - return union_all(self, other, **kwargs) - - def _find_engine(self): - """Try to return a Engine, either explicitly set in this - object, or searched within the from clauses for one. - """ - - if self._bind is not None: - return self._bind - for f in self._froms: - if f is self: - continue - e = f.bind - if e is not None: - self._bind = e - return e - # look through the columns (largely synomous with looking - # through the FROMs except in the case of _CalculatedClause/_Function) - for cc in self._exportable_columns(): - for c in cc.columns: - if getattr(c, 'table', None) is self: - continue - e = c.bind - if e is not None: - self._bind = e - return e - return None - -class _UpdateBase(ClauseElement): - """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" - - def supports_execution(self): - return True - - def _calculate_correlations(self, correlate_state): - class SelectCorrelator(NoColumnVisitor): - def visit_select(s, select): - if select._should_correlate: - select_state = correlate_state.setdefault(select, {}) - corr = select_state.setdefault('correlate', util.Set()) - corr.add(self.table) - - vis = SelectCorrelator() - - if self._whereclause is not None: - vis.traverse(self._whereclause) - - if getattr(self, 'parameters', None) is not None: - for key, value in self.parameters.items(): - if isinstance(value, ClauseElement): - vis.traverse(value) - - def _process_colparams(self, parameters): - """Receive the *values* of an ``INSERT`` or ``UPDATE`` - statement and construct appropriate bind parameters. 
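    For example (hypothetical ``users`` table), values may be a dictionary
    keyed by column name or ``Column`` object, or a positional list matching
    the table's columns:

        ins = users.insert(values={'id': 7, 'name': 'jack'})
        upd = users.update(users.c.id == 7, values={users.c.name: 'ed'})
        # literal values become bind parameters; ClauseElement values are
        # rendered as SQL expressions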
- """ - - if parameters is None: - return None - - if isinstance(parameters, (list, tuple)): - pp = {} - i = 0 - for c in self.table.c: - pp[c.key] = parameters[i] - i +=1 - parameters = pp - - for key in parameters.keys(): - value = parameters[key] - if isinstance(value, ClauseElement): - parameters[key] = value.self_group() - elif _is_literal(value): - if _is_literal(key): - col = self.table.c[key] - else: - col = key - try: - parameters[key] = bindparam(col, value, unique=True) - except KeyError: - del parameters[key] - return parameters - - def _find_engine(self): - return self.table.bind - -class Insert(_UpdateBase): - def __init__(self, table, values=None): - self.table = table - self.select = None - self.parameters = self._process_colparams(values) - - def get_children(self, **kwargs): - if self.select is not None: - return self.select, - else: - return () - -class Update(_UpdateBase): - def __init__(self, table, whereclause, values=None): - self.table = table - self._whereclause = whereclause - self.parameters = self._process_colparams(values) - - def get_children(self, **kwargs): - if self._whereclause is not None: - return self._whereclause, - else: - return () - -class Delete(_UpdateBase): - def __init__(self, table, whereclause): - self.table = table - self._whereclause = whereclause - - def get_children(self, **kwargs): - if self._whereclause is not None: - return self._whereclause, - else: - return () - -class _IdentifiedClause(ClauseElement): - def __init__(self, ident): - self.ident = ident - def supports_execution(self): - return True - -class SavepointClause(_IdentifiedClause): - pass - -class RollbackToSavepointClause(_IdentifiedClause): - pass - -class ReleaseSavepointClause(_IdentifiedClause): - pass diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py new file mode 100644 index 0000000000..c966f396a2 --- /dev/null +++ b/lib/sqlalchemy/sql/__init__.py @@ -0,0 +1,2 @@ +from sqlalchemy.sql.expression import * +from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py new file mode 100644 index 0000000000..1fe9ef0622 --- /dev/null +++ b/lib/sqlalchemy/sql/compiler.py @@ -0,0 +1,1107 @@ +# compiler.py +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Base SQL and DDL compiler implementations. + +Provides the [sqlalchemy.sql.compiler#DefaultCompiler] class, which is +responsible for generating all SQL query strings, as well as +[sqlalchemy.sql.compiler#SchemaGenerator] and [sqlalchemy.sql.compiler#SchemaDropper] +which issue CREATE and DROP DDL for tables, sequences, and indexes. + +The elements in this module are used by public-facing constructs like +[sqlalchemy.sql.expression#ClauseElement] and [sqlalchemy.engine#Engine]. +While dialect authors will want to be familiar with this module for the purpose of +creating database-specific compilers and schema generators, the module +is otherwise internal to SQLAlchemy. 
+""" + +import string, re, itertools +from sqlalchemy import schema, engine, util, exceptions +from sqlalchemy.sql import operators, functions +from sqlalchemy.sql import expression as sql + +RESERVED_WORDS = util.Set([ + 'all', 'analyse', 'analyze', 'and', 'any', 'array', + 'as', 'asc', 'asymmetric', 'authorization', 'between', + 'binary', 'both', 'case', 'cast', 'check', 'collate', + 'column', 'constraint', 'create', 'cross', 'current_date', + 'current_role', 'current_time', 'current_timestamp', + 'current_user', 'default', 'deferrable', 'desc', + 'distinct', 'do', 'else', 'end', 'except', 'false', + 'for', 'foreign', 'freeze', 'from', 'full', 'grant', + 'group', 'having', 'ilike', 'in', 'initially', 'inner', + 'intersect', 'into', 'is', 'isnull', 'join', 'leading', + 'left', 'like', 'limit', 'localtime', 'localtimestamp', + 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', + 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', + 'placing', 'primary', 'references', 'right', 'select', + 'session_user', 'set', 'similar', 'some', 'symmetric', 'table', + 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', + 'using', 'verbose', 'when', 'where']) + +LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I) +ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]') + +BIND_PARAMS = re.compile(r'(?', + operators.ge : '>=', + operators.eq : '=', + operators.distinct_op : 'DISTINCT', + operators.concat_op : '||', + operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), + operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), + operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), + operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), + operators.between_op : 'BETWEEN', + operators.in_op : 'IN', + operators.notin_op : 'NOT IN', + operators.comma_op : ', ', + operators.desc_op : 'DESC', + operators.asc_op : 'ASC', + operators.from_ : 'FROM', + operators.as_ : 'AS', + operators.exists : 'EXISTS', + operators.is_ : 'IS', + operators.isnot : 'IS NOT', + operators.collate : 'COLLATE', +} + +FUNCTIONS = { + functions.coalesce : 'coalesce%(expr)s', + functions.current_date: 'CURRENT_DATE', + functions.current_time: 'CURRENT_TIME', + functions.current_timestamp: 'CURRENT_TIMESTAMP', + functions.current_user: 'CURRENT_USER', + functions.localtime: 'LOCALTIME', + functions.localtimestamp: 'LOCALTIMESTAMP', + functions.random: 'random%(expr)s', + functions.sysdate: 'sysdate', + functions.session_user :'SESSION_USER', + functions.user: 'USER' +} + +class DefaultCompiler(engine.Compiled): + """Default implementation of Compiled. + + Compiles ClauseElements into SQL strings. Uses a similar visit + paradigm as visitors.ClauseVisitor but implements its own traversal. + """ + + __traverse_options__ = {'column_collections':False, 'entry':True} + + operators = OPERATORS + functions = FUNCTIONS + + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): + """Construct a new ``DefaultCompiler`` object. + + dialect + Dialect to be used + + statement + ClauseElement to be compiled + + column_keys + a list of column names to be compiled into an INSERT or UPDATE + statement. 
+ """ + + super(DefaultCompiler, self).__init__(dialect, statement, column_keys, **kwargs) + + # if we are insert/update/delete. set to true when we visit an INSERT, UPDATE or DELETE + self.isdelete = self.isinsert = self.isupdate = False + + # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) + self.inline = inline or getattr(statement, 'inline', False) + + # a dictionary of bind parameter keys to _BindParamClause instances. + self.binds = {} + + # a dictionary of _BindParamClause instances to "compiled" names that are + # actually present in the generated SQL + self.bind_names = {} + + # a stack. what recursive compiler doesn't have a stack ? :) + self.stack = [] + + # relates label names in the final SQL to + # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine. + # ResultProxy uses this for type processing and column targeting + self.result_map = {} + + # a dictionary of ClauseElement subclasses to counters, which are used to + # generate truncated identifier names or "anonymous" identifiers such as + # for aliases + self.generated_ids = {} + + # paramstyle from the dialect (comes from DB-API) + self.paramstyle = self.dialect.paramstyle + + # true if the paramstyle is positional + self.positional = self.dialect.positional + + self.bindtemplate = BIND_TEMPLATES[self.paramstyle] + + # a list of the compiled's bind parameter names, used to help + # formulate a positional argument list + self.positiontup = [] + + # an IdentifierPreparer that formats the quoting of identifiers + self.preparer = self.dialect.identifier_preparer + + def compile(self): + self.string = self.process(self.statement) + + def process(self, obj, stack=None, **kwargs): + if stack: + self.stack.append(stack) + try: + meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kwargs) + finally: + if stack: + self.stack.pop(-1) + + def is_subquery(self, select): + return self.stack and self.stack[-1].get('is_subquery') + + def construct_params(self, params=None): + """return a dictionary of bind parameter keys and values""" + + if params: + pd = {} + for bindparam, name in self.bind_names.iteritems(): + for paramname in (bindparam, bindparam.key, bindparam.shortname, name): + if paramname in params: + pd[name] = params[paramname] + break + else: + if callable(bindparam.value): + pd[name] = bindparam.value() + else: + pd[name] = bindparam.value + return pd + else: + pd = {} + for bindparam in self.bind_names: + if callable(bindparam.value): + pd[self.bind_names[bindparam]] = bindparam.value() + else: + pd[self.bind_names[bindparam]] = bindparam.value + return pd + + params = property(construct_params) + + def default_from(self): + """Called when a SELECT statement has no froms, and no FROM clause is to be appended. + + Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. 
+ """ + + return "" + + def visit_grouping(self, grouping, **kwargs): + return "(" + self.process(grouping.elem) + ")" + + def visit_label(self, label, result_map=None): + labelname = self._truncated_identifier("colident", label.name) + + if result_map is not None: + result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type) + + return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) + + def visit_column(self, column, result_map=None, **kwargs): + + if column._is_oid: + name = self.dialect.oid_column_name(column) + if name is None: + if len(column.table.primary_key) != 0: + pk = list(column.table.primary_key)[0] + return self.visit_column(pk, result_map=result_map, **kwargs) + else: + return None + elif not column.is_literal: + name = self._truncated_identifier("colident", column.name) + else: + name = column.name + + if result_map is not None: + result_map[name.lower()] = (name, (column, ), column.type) + + if getattr(column, "is_literal", False): + name = self.escape_literal_column(name) + else: + name = self.preparer.quote(column, name) + + if column.table is None or not column.table.named_with_column: + return name + else: + if getattr(column.table, 'schema', None): + schema_prefix = self.preparer.quote(column.table, column.table.schema) + '.' + else: + schema_prefix = '' + return schema_prefix + self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name + + def escape_literal_column(self, text): + """provide escaping for the literal_column() construct.""" + + # TODO: some dialects might need different behavior here + return text.replace('%', '%%') + + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name + + def visit_index(self, index, **kwargs): + return index.name + + def visit_typeclause(self, typeclause, **kwargs): + return typeclause.type.dialect_impl(self.dialect).get_col_spec() + + def visit_textclause(self, textclause, **kwargs): + if textclause.typemap is not None: + for colname, type_ in textclause.typemap.iteritems(): + self.result_map[colname.lower()] = (colname, None, type_) + + def do_bindparam(m): + name = m.group(1) + if name in textclause.bindparams: + return self.process(textclause.bindparams[name]) + else: + return self.bindparam_string(name) + + # un-escape any \:params + return BIND_PARAMS_ESC.sub(lambda m: m.group(1), + BIND_PARAMS.sub(do_bindparam, textclause.text) + ) + + def visit_null(self, null, **kwargs): + return 'NULL' + + def visit_clauselist(self, clauselist, **kwargs): + sep = clauselist.operator + if sep is None: + sep = " " + elif sep == operators.comma_op: + sep = ', ' + else: + sep = " " + self.operator_string(clauselist.operator) + " " + return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None]) + + def visit_calculatedclause(self, clause, **kwargs): + return self.process(clause.clause_expr) + + def visit_cast(self, cast, **kwargs): + return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) + + def visit_function(self, func, result_map=None, **kwargs): + if result_map is not None: + result_map[func.name.lower()] = (func.name, None, func.type) + + name = self.function_string(func) + + if callable(name): + return name(*[self.process(x) for x in func.clauses]) + else: + return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} + + def function_argspec(self, func): + return 
self.process(func.clause_expr) + + def function_string(self, func): + return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s")) + + def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): + stack_entry = {'select':cs} + + if asfrom: + stack_entry['is_subquery'] = True + elif self.stack and self.stack[-1].get('select'): + stack_entry['is_subquery'] = True + self.stack.append(stack_entry) + + text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ") + group_by = self.process(cs._group_by_clause, asfrom=asfrom) + if group_by: + text += " GROUP BY " + group_by + + text += self.order_by_clause(cs) + text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def visit_unary(self, unary, **kwargs): + s = self.process(unary.element) + if unary.operator: + s = self.operator_string(unary.operator) + " " + s + if unary.modifier: + s = s + " " + self.operator_string(unary.modifier) + return s + + def visit_binary(self, binary, **kwargs): + op = self.operator_string(binary.operator) + if callable(op): + return op(self.process(binary.left), self.process(binary.right), **binary.modifiers) + else: + return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + def operator_string(self, operator): + return self.operators.get(operator, str(operator)) + + def visit_bindparam(self, bindparam, **kwargs): + name = self._truncate_bindparam(bindparam) + if name in self.binds: + existing = self.binds[name] + if existing is not bindparam and (existing.unique or bindparam.unique): + raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + self.binds[bindparam.key] = self.binds[name] = bindparam + return self.bindparam_string(name) + + def _truncate_bindparam(self, bindparam): + if bindparam in self.bind_names: + return self.bind_names[bindparam] + + bind_name = bindparam.key + bind_name = self._truncated_identifier("bindparam", bind_name) + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + + return bind_name + + def _truncated_identifier(self, ident_class, name): + if (ident_class, name) in self.generated_ids: + return self.generated_ids[(ident_class, name)] + + anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) + + if len(anonname) > self.dialect.max_identifier_length: + counter = self.generated_ids.get(ident_class, 1) + truncname = anonname[0:self.dialect.max_identifier_length - 6] + "_" + hex(counter)[2:] + self.generated_ids[ident_class] = counter + 1 + else: + truncname = anonname + self.generated_ids[(ident_class, name)] = truncname + return truncname + + def _process_anon(self, match): + (ident, derived) = match.group(1,2) + + key = ('anonymous', ident) + if key in self.generated_ids: + return self.generated_ids[key] + else: + anonymous_counter = self.generated_ids.get(('anon_counter', derived), 1) + newname = derived + "_" + str(anonymous_counter) + self.generated_ids[('anon_counter', derived)] = anonymous_counter + 1 + self.generated_ids[key] = newname + return newname + + def _anonymize(self, name): + return ANONYMOUS_LABEL.sub(self._process_anon, name) + + def bindparam_string(self, name): + if self.positional: + self.positiontup.append(name) + + return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} + + def visit_alias(self, alias, asfrom=False, 
**kwargs): + if asfrom: + return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name)) + else: + return self.process(alias.original, **kwargs) + + def label_select_column(self, select, column, asfrom): + """label columns present in a select().""" + + if isinstance(column, sql._Label): + return column + + if select.use_labels and getattr(column, '_label', None): + return column.label(column._label) + + if \ + asfrom and \ + isinstance(column, sql._ColumnClause) and \ + not column.is_literal and \ + column.table is not None and \ + not isinstance(column.table, sql.Select): + return column.label(column.name) + elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and (not hasattr(column, 'name') or isinstance(column, sql._Function)): + return column.label(column.anon_label) + else: + return column + + def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, **kwargs): + + stack_entry = {'select':select} + prev_entry = self.stack and self.stack[-1] or None + + if asfrom or (prev_entry and 'select' in prev_entry): + stack_entry['is_subquery'] = True + if prev_entry and 'iswrapper' in prev_entry: + column_clause_args = {'result_map':self.result_map} + else: + column_clause_args = {} + elif iswrapper: + column_clause_args = {} + stack_entry['iswrapper'] = True + else: + column_clause_args = {'result_map':self.result_map} + + if self.stack and 'from' in self.stack[-1]: + existingfroms = self.stack[-1]['from'] + else: + existingfroms = None + + froms = select._get_display_froms(existingfroms) + + correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms]))) + + # TODO: might want to propigate existing froms for select(select(select)) + # where innermost select should correlate to outermost +# if existingfroms: +# correlate_froms = correlate_froms.union(existingfroms) + stack_entry['from'] = correlate_froms + self.stack.append(stack_entry) + + # the actual list of columns to print in the SELECT column list. 
+ inner_columns = util.OrderedSet( + [c for c in [ + self.process( + self.label_select_column(select, co, asfrom=asfrom), + **column_clause_args) + for co in select.inner_columns + ] + if c is not None] + ) + + text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " + text += self.get_select_precolumns(select) + text += ', '.join(inner_columns) + + from_strings = [] + for f in froms: + from_strings.append(self.process(f, asfrom=True)) + + if froms: + text += " \nFROM " + text += string.join(from_strings, ', ') + else: + text += self.default_from() + + if select._whereclause is not None: + t = self.process(select._whereclause) + if t: + text += " \nWHERE " + t + + group_by = self.process(select._group_by_clause) + if group_by: + text += " GROUP BY " + group_by + + if select._having is not None: + t = self.process(select._having) + if t: + text += " \nHAVING " + t + + text += self.order_by_clause(select) + text += (select._limit or select._offset) and self.limit_clause(select) or "" + text += self.for_update_clause(select) + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + def get_select_precolumns(self, select): + """Called when building a ``SELECT`` statement, position is just before column list.""" + + return select._distinct and "DISTINCT " or "" + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def for_update_clause(self, select): + if select.for_update: + return " FOR UPDATE" + else: + return "" + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT -1" + text += " OFFSET " + str(select._offset) + return text + + def visit_table(self, table, asfrom=False, **kwargs): + if asfrom: + if getattr(table, "schema", None): + return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name) + else: + return self.preparer.quote(table, table.name) + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + + def visit_sequence(self, seq): + return None + + def visit_insert(self, insert_stmt): + self.isinsert = True + colparams = self._get_colparams(insert_stmt) + preparer = self.preparer + + insert = ' '.join(["INSERT"] + + [self.process(x) for x in insert_stmt._prefixes]) + + return (insert + " INTO %s (%s) VALUES (%s)" % + (preparer.format_table(insert_stmt.table), + ', '.join([preparer.quote(c[0], c[0].name) + for c in colparams]), + ', '.join([c[1] for c in colparams]))) + + def visit_update(self, update_stmt): + self.stack.append({'from':util.Set([update_stmt.table])}) + + self.isupdate = True + colparams = self._get_colparams(update_stmt) + + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ') + + if update_stmt._whereclause: + text += " WHERE " + self.process(update_stmt._whereclause) + + self.stack.pop(-1) + + return text + + def _get_colparams(self, stmt): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. 
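        As a rough illustration, for a hypothetical ``users`` table compiled
        with the 'named' paramstyle, an INSERT with no explicit parameters
        produces pairs resembling [(users.c.id, ':id'), (users.c.name, ':name')];
        each tuple pairs the Column with the bind parameter (or default) text
        that visit_insert() and visit_update() splice into the statement.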
+ + """ + + def create_bind_param(col, value): + bindparam = sql.bindparam(col.key, value, type_=col.type) + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + + self.postfetch = [] + self.prefetch = [] + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if self.column_keys is None and stmt.parameters is None: + return [(c, create_bind_param(c, None)) for c in stmt.table.columns] + + # if we have statement parameters - set defaults in the + # compiled params + if self.column_keys is None: + parameters = {} + else: + parameters = dict([(getattr(key, 'key', key), None) for key in self.column_keys]) + + if stmt.parameters is not None: + for k, v in stmt.parameters.iteritems(): + parameters.setdefault(getattr(k, 'key', k), v) + + # create a list of column assignment clauses as tuples + values = [] + for c in stmt.table.columns: + if c.key in parameters: + value = parameters[c.key] + if sql._is_literal(value): + value = create_bind_param(c, value) + else: + self.postfetch.append(c) + value = self.process(value.self_group()) + values.append((c, value)) + elif isinstance(c, schema.Column): + if self.isinsert: + if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline): + if (((isinstance(c.default, schema.Sequence) and + not c.default.optional) or + not self.dialect.supports_pk_autoincrement) or + (c.default is not None and + not isinstance(c.default, schema.Sequence))): + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + elif isinstance(c.default, schema.ColumnDefault): + if isinstance(c.default.arg, sql.ClauseElement): + values.append((c, self.process(c.default.arg.self_group()))) + if not c.primary_key: + # dont add primary key column to postfetch + self.postfetch.append(c) + else: + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + elif isinstance(c.default, schema.PassiveDefault): + if not c.primary_key: + self.postfetch.append(c) + elif isinstance(c.default, schema.Sequence): + proc = self.process(c.default) + if proc is not None: + values.append((c, proc)) + if not c.primary_key: + self.postfetch.append(c) + elif self.isupdate: + if isinstance(c.onupdate, schema.ColumnDefault): + if isinstance(c.onupdate.arg, sql.ClauseElement): + values.append((c, self.process(c.onupdate.arg.self_group()))) + self.postfetch.append(c) + else: + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + elif isinstance(c.onupdate, schema.PassiveDefault): + self.postfetch.append(c) + return values + + def visit_delete(self, delete_stmt): + self.stack.append({'from':util.Set([delete_stmt.table])}) + self.isdelete = True + + text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + + if delete_stmt._whereclause: + text += " WHERE " + self.process(delete_stmt._whereclause) + + self.stack.pop(-1) + + return text + + def visit_savepoint(self, savepoint_stmt): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_release_savepoint(self, savepoint_stmt): + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def __str__(self): + return self.string or '' + +class DDLBase(engine.SchemaIterator): + def find_alterables(self, tables): + alterables = [] + class FindAlterables(schema.SchemaVisitor): + def 
visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and constraint.table in tables: + alterables.append(constraint) + findalterables = FindAlterables() + for table in tables: + for c in table.constraints: + findalterables.traverse(c) + return alterables + +class SchemaGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables and util.Set(tables) or None + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def get_column_specification(self, column, first_pk=False): + raise NotImplementedError() + + def visit_metadata(self, metadata): + collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] + for table in collection: + self.traverse_single(table) + if self.dialect.supports_alter: + for alterable in self.find_alterables(collection): + self.add_foreignkey(alterable) + + def visit_table(self, table): + for listener in table.ddl_listeners['before-create']: + listener('before-create', table, self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") + + separator = "\n" + + # if only one primary key, specify it along with the column + first_pk = False + for column in table.columns: + self.append(separator) + separator = ", \n" + self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) + if column.primary_key: + first_pk = True + for constraint in column.constraints: + self.traverse_single(constraint) + + # On some DB order is significant: visit PK first, then the + # other constraints (engine.ReflectionTest.testbasic failed on FB2) + if table.primary_key: + self.traverse_single(table.primary_key) + for constraint in [c for c in table.constraints if c is not table.primary_key]: + self.traverse_single(constraint) + + self.append("\n)%s\n\n" % self.post_create_table(table)) + self.execute() + + if hasattr(table, 'indexes'): + for index in table.indexes: + self.traverse_single(index) + + for listener in table.ddl_listeners['after-create']: + listener('after-create', table, self.connection) + + def post_create_table(self, table): + return '' + + def get_column_default_string(self, column): + if isinstance(column.default, schema.PassiveDefault): + if isinstance(column.default.arg, basestring): + return "'%s'" % column.default.arg + else: + return unicode(self._compile(column.default.arg, None)) + else: + return None + + def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" + compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) + compiler.compile() + return compiler + + def visit_check_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(constraint)) + self.append(" CHECK (%s)" % constraint.sqltext) + self.define_constraint_deferrability(constraint) + + def visit_column_check_constraint(self, constraint): + self.append(" CHECK (%s)" % constraint.sqltext) + self.define_constraint_deferrability(constraint) + + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return + 
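# Hedged sketch (not part of the patch): SchemaGenerator/SchemaDropper are
# normally driven via MetaData.create_all()/drop_all(); names are invented.
from sqlalchemy import MetaData, Table, Column, Integer, String, create_engine

engine = create_engine('sqlite://')
meta = MetaData()
users = Table('users', meta, Column('id', Integer, primary_key=True), Column('name', String(40)))

# checkfirst=True corresponds to the has_table() filter in visit_metadata()
# above, so tables that already exist are skipped rather than re-created.
meta.create_all(bind=engine, checkfirst=True)
meta.drop_all(bind=engine)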
self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) + self.append("PRIMARY KEY ") + self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint])) + self.define_constraint_deferrability(constraint) + + def visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and self.dialect.supports_alter: + return + self.append(", \n\t ") + self.define_foreign_key(constraint) + + def add_foreignkey(self, constraint): + self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) + self.define_foreign_key(constraint) + self.execute() + + def define_foreign_key(self, constraint): + preparer = self.preparer + if constraint.name is not None: + self.append("CONSTRAINT %s " % + preparer.format_constraint(constraint)) + table = list(constraint.elements)[0].column.table + self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]), + preparer.format_table(table), + ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements]) + )) + if constraint.ondelete is not None: + self.append(" ON DELETE %s" % constraint.ondelete) + if constraint.onupdate is not None: + self.append(" ON UPDATE %s" % constraint.onupdate) + self.define_constraint_deferrability(constraint) + + def visit_unique_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(constraint)) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint]))) + self.define_constraint_deferrability(constraint) + + def define_constraint_deferrability(self, constraint): + if constraint.deferrable is not None: + if constraint.deferrable: + self.append(" DEFERRABLE") + else: + self.append(" NOT DEFERRABLE") + if constraint.initially is not None: + self.append(" INITIALLY %s" % constraint.initially) + + def visit_column(self, column): + pass + + def visit_index(self, index): + preparer = self.preparer + self.append("CREATE ") + if index.unique: + self.append("UNIQUE ") + self.append("INDEX %s ON %s (%s)" \ + % (preparer.format_index(index), + preparer.format_table(index.table), + string.join([preparer.quote(c, c.name) for c in index.columns], ', '))) + self.execute() + + +class SchemaDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def visit_metadata(self, metadata): + collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))] + if self.dialect.supports_alter: + for alterable in self.find_alterables(collection): + self.drop_foreignkey(alterable) + for table in collection: + self.traverse_single(table) + + def visit_index(self, index): + self.append("\nDROP INDEX " + self.preparer.format_index(index)) + self.execute() + + def drop_foreignkey(self, constraint): + self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( + self.preparer.format_table(constraint.table), + self.preparer.format_constraint(constraint))) + self.execute() + + def visit_table(self, table): + for listener in table.ddl_listeners['before-drop']: + listener('before-drop', table, 
self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.append("\nDROP TABLE " + self.preparer.format_table(table)) + self.execute() + + for listener in table.ddl_listeners['after-drop']: + listener('after-drop', table, self.connection) + + +class IdentifierPreparer(object): + """Handle quoting and case-folding of identifiers based on options.""" + + reserved_words = RESERVED_WORDS + + legal_characters = LEGAL_CHARACTERS + + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + + def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): + """Construct a new ``IdentifierPreparer`` object. + + initial_quote + Character that begins a delimited identifier. + + final_quote + Character that ends a delimited identifier. Defaults to `initial_quote`. + + omit_schema + Prevent prepending schema name. Useful for databases that do + not support schemae. + """ + + self.dialect = dialect + self.initial_quote = initial_quote + self.final_quote = final_quote or self.initial_quote + self.omit_schema = omit_schema + self.__strings = {} + + def _escape_identifier(self, value): + """Escape an identifier. + + Subclasses should override this to provide database-dependent + escaping behavior. + """ + + return value.replace('"', '""') + + def _unescape_identifier(self, value): + """Canonicalize an escaped identifier. + + Subclasses should override this to provide database-dependent + unescaping behavior that reverses _escape_identifier. + """ + + return value.replace('""', '"') + + def quote_identifier(self, value): + """Quote an identifier. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + + return self.initial_quote + self._escape_identifier(value) + self.final_quote + + def _requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return (lc_value in self.reserved_words + or self.illegal_initial_characters.match(value[0]) + or not self.legal_characters.match(unicode(value)) + or (lc_value != value)) + + def quote(self, obj, ident): + if getattr(obj, 'quote', False): + return self.quote_identifier(ident) + if ident in self.__strings: + return self.__strings[ident] + else: + if self._requires_quotes(ident): + self.__strings[ident] = self.quote_identifier(ident) + else: + self.__strings[ident] = ident + return self.__strings[ident] + + def should_quote(self, object): + return object.quote or self._requires_quotes(object.name) + + def format_sequence(self, sequence, use_schema=True): + name = self.quote(sequence, sequence.name) + if not self.omit_schema and use_schema and sequence.schema is not None: + name = self.quote(sequence, sequence.schema) + "." 
+ name + return name + + def format_label(self, label, name=None): + return self.quote(label, name or label.name) + + def format_alias(self, alias, name=None): + return self.quote(alias, name or alias.name) + + def format_savepoint(self, savepoint, name=None): + return self.quote(savepoint, name or savepoint.ident) + + def format_constraint(self, constraint): + return self.quote(constraint, constraint.name) + + def format_index(self, index): + return self.quote(index, index.name) + + def format_table(self, table, use_schema=True, name=None): + """Prepare a quoted table and schema name.""" + + if name is None: + name = table.name + result = self.quote(table, name) + if not self.omit_schema and use_schema and getattr(table, "schema", None): + result = self.quote(table, table.schema) + "." + result + return result + + def format_column(self, column, use_table=False, name=None, table_name=None): + """Prepare a quoted column name. + + deprecated. use preparer.quote(col, column.name) or combine with format_table() + """ + + if name is None: + name = column.name + if not getattr(column, 'is_literal', False): + if use_table: + return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name) + else: + return self.quote(column, name) + else: + # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted + if use_table: + return self.format_table(column.table, use_schema=False, name=table_name) + "." + name + else: + return name + + def format_table_seq(self, table, use_schema=True): + """Format table name and schema as a tuple.""" + + # Dialects with more levels in their fully qualified references + # ('database', 'owner', etc.) could override this and return + # a longer sequence. + + if not self.omit_schema and use_schema and getattr(table, 'schema', None): + return (self.quote_identifier(table.schema), + self.format_table(table, use_schema=False)) + else: + return (self.format_table(table, use_schema=False), ) + + def unformat_identifiers(self, identifiers): + """Unpack 'schema.table.column'-like strings into components.""" + + try: + r = self._r_identifiers + except AttributeError: + initial, final, escaped_final = \ + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] + r = re.compile( + r'(?:' + r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' + r'|([^\.]+))(?=\.|$))+' % + { 'initial': initial, + 'final': final, + 'escaped': escaped_final }) + self._r_identifiers = r + + return [self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)]] diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py new file mode 100644 index 0000000000..867fdd69c3 --- /dev/null +++ b/lib/sqlalchemy/sql/expression.py @@ -0,0 +1,3517 @@ +# sql.py +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines the base components of SQL expression trees. + +All components are derived from a common base class +[sqlalchemy.sql.expression#ClauseElement]. Common behaviors are organized +based on class hierarchies, in some cases via mixins. + +All object construction from this package occurs via functions which +in some cases will construct composite ``ClauseElement`` structures +together, and in other cases simply return a single ``ClauseElement`` +constructed directly. 
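# A small illustration (not part of the patch) of the function-based
# construction style described here; table and column names are invented.
from sqlalchemy import MetaData, Table, Column, Integer, String
from sqlalchemy.sql.expression import select, and_, not_

meta = MetaData()
users = Table('users', meta, Column('id', Integer, primary_key=True), Column('name', String(40)))

# select(), and_() and not_() each return ClauseElement structures which
# render as SQL when str() is applied or when executed against an engine.
print select([users.c.name], and_(users.c.id > 5, not_(users.c.name.like('ed%'))))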
The function interface affords a more "DSL-ish" +feel to constructing SQL expressions and also allows future class +reorganizations. + +Even though classes are not constructed directly from the outside, +most classes which have additional public methods are considered to be +public (i.e. have no leading underscore). Other classes which are +"semi-public" are marked with a single leading underscore; these +classes usually have few or no public methods and are less guaranteed +to stay the same in future releases. +""" + +import itertools, re +from sqlalchemy import util, exceptions +from sqlalchemy.sql import operators, visitors +from sqlalchemy import types as sqltypes + +functions, schema, sql_util = None, None, None +DefaultDialect, ClauseAdapter = None, None + +__all__ = [ + 'Alias', 'ClauseElement', + 'ColumnCollection', 'ColumnElement', + 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', + 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', + 'between', 'bindparam', 'case', 'cast', 'column', 'delete', + 'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', + 'modifier', 'collate', + 'insert', 'intersect', 'intersect_all', 'join', 'literal', + 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select', + 'subquery', 'table', 'text', 'union', 'union_all', 'update', ] + + + +def desc(column): + """Return a descending ``ORDER BY`` clause element. + + e.g.:: + + order_by = [desc(table1.mycol)] + """ + return _UnaryExpression(column, modifier=operators.desc_op) + +def asc(column): + """Return an ascending ``ORDER BY`` clause element. + + e.g.:: + + order_by = [asc(table1.mycol)] + """ + return _UnaryExpression(column, modifier=operators.asc_op) + +def outerjoin(left, right, onclause=None): + """Return an ``OUTER JOIN`` clause element. + + The returned object is an instance of [sqlalchemy.sql.expression#Join]. + + Similar functionality is also available via the ``outerjoin()`` + method on any [sqlalchemy.sql.expression#FromClause]. + + left + The left side of the join. + + right + The right side of the join. + + onclause + Optional criterion for the ``ON`` clause, is derived from + foreign key relationships established between left and right + otherwise. + + To chain joins together, use the ``join()`` or ``outerjoin()`` + methods on the resulting ``Join`` object. + """ + + return Join(left, right, onclause, isouter=True) + +def join(left, right, onclause=None, isouter=False): + """Return a ``JOIN`` clause element (regular inner join). + + The returned object is an instance of [sqlalchemy.sql.expression#Join]. + + Similar functionality is also available via the ``join()`` method + on any [sqlalchemy.sql.expression#FromClause]. + + left + The left side of the join. + + right + The right side of the join. + + onclause + Optional criterion for the ``ON`` clause, is derived from + foreign key relationships established between left and right + otherwise. + + To chain joins together, use the ``join()`` or ``outerjoin()`` + methods on the resulting ``Join`` object. + """ + + return Join(left, right, onclause, isouter) + +def select(columns=None, whereclause=None, from_obj=[], **kwargs): + """Returns a ``SELECT`` clause element. + + Similar functionality is also available via the ``select()`` + method on any [sqlalchemy.sql.expression#FromClause]. + + The returned object is an instance of [sqlalchemy.sql.expression#Select]. 
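# e.g. (illustrative, not from the patch): join() derives its ON clause from
# the ForeignKey between these invented tables, and the resulting Join can be
# used directly as the FROM clause of a select().
from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey
from sqlalchemy.sql.expression import join, select

meta = MetaData()
users = Table('users', meta, Column('id', Integer, primary_key=True), Column('name', String(40)))
addresses = Table('addresses', meta, Column('id', Integer, primary_key=True),
                  Column('user_id', Integer, ForeignKey('users.id')), Column('email', String(80)))

j = join(users, addresses)   # users JOIN addresses ON users.id = addresses.user_id
print select([users.c.name, addresses.c.email], from_obj=[j])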
+ + All arguments which accept ``ClauseElement`` arguments also accept + string arguments, which will be converted as appropriate into + either ``text()`` or ``literal_column()`` constructs. + + columns + A list of ``ClauseElement`` objects, typically ``ColumnElement`` + objects or subclasses, which will form the columns clause of the + resulting statement. For all members which are instances of + ``Selectable``, the individual ``ColumnElement`` members of the + ``Selectable`` will be added individually to the columns clause. + For example, specifying a ``Table`` instance will result in all + the contained ``Column`` objects within to be added to the + columns clause. + + This argument is not present on the form of ``select()`` + available on ``Table``. + + whereclause + A ``ClauseElement`` expression which will be used to form the + ``WHERE`` clause. + + from_obj + A list of ``ClauseElement`` objects which will be added to the + ``FROM`` clause of the resulting statement. Note that "from" + objects are automatically located within the columns and + whereclause ClauseElements. Use this parameter to explicitly + specify "from" objects which are not automatically locatable. + This could include ``Table`` objects that aren't otherwise + present, or ``Join`` objects whose presence will supercede that + of the ``Table`` objects already located in the other clauses. + + \**kwargs + Additional parameters include: + + autocommit + indicates this SELECT statement modifies the database, and + should be subject to autocommit behavior if no transaction + has been started. + + prefixes + a list of strings or ``ClauseElement`` objects to include + directly after the SELECT keyword in the generated statement, + for dialect-specific query features. + + distinct=False + when ``True``, applies a ``DISTINCT`` qualifier to the columns + clause of the resulting statement. + + use_labels=False + when ``True``, the statement will be generated using labels + for each column in the columns clause, which qualify each + column with its parent table's (or aliases) name so that name + conflicts between columns in different tables don't occur. + The format of the label is _. The "c" + collection of the resulting ``Select`` object will use these + names as well for targeting column members. + + for_update=False + when ``True``, applies ``FOR UPDATE`` to the end of the + resulting statement. Certain database dialects also support + alternate values for this parameter, for example mysql + supports "read" which translates to ``LOCK IN SHARE MODE``, + and oracle supports "nowait" which translates to ``FOR UPDATE + NOWAIT``. + + correlate=True + indicates that this ``Select`` object should have its + contained ``FromClause`` elements "correlated" to an enclosing + ``Select`` object. This means that any ``ClauseElement`` + instance within the "froms" collection of this ``Select`` + which is also present in the "froms" collection of an + enclosing select will not be rendered in the ``FROM`` clause + of this select statement. + + group_by + a list of ``ClauseElement`` objects which will comprise the + ``GROUP BY`` clause of the resulting select. + + having + a ``ClauseElement`` that will comprise the ``HAVING`` clause + of the resulting select when ``GROUP BY`` is used. + + order_by + a scalar or list of ``ClauseElement`` objects which will + comprise the ``ORDER BY`` clause of the resulting select. + + limit=None + a numerical value which usually compiles to a ``LIMIT`` + expression in the resulting select. 
Databases that don't + support ``LIMIT`` will attempt to provide similar + functionality. + + offset=None + a numeric value which usually compiles to an ``OFFSET`` + expression in the resulting select. Databases that don't + support ``OFFSET`` will attempt to provide similar + functionality. + + bind=None + an ``Engine`` or ``Connection`` instance to which the + resulting ``Select ` object will be bound. The ``Select`` + object will otherwise automatically bind to whatever + ``Connectable`` instances can be located within its contained + ``ClauseElement`` members. + + scalar=False + deprecated. Use select(...).as_scalar() to create a "scalar + column" proxy for an existing Select object. + """ + + if 'scalar' in kwargs: + util.warn_deprecated('scalar option is deprecated; see docs for details') + scalar = kwargs.pop('scalar', False) + s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) + if scalar: + return s.as_scalar() + else: + return s + +def subquery(alias, *args, **kwargs): + """Return an [sqlalchemy.sql.expression#Alias] object derived from a [sqlalchemy.sql.expression#Select]. + + name + alias name + + \*args, \**kwargs + + all other arguments are delivered to the [sqlalchemy.sql.expression#select()] + function. + """ + + return Select(*args, **kwargs).alias(alias) + +def insert(table, values=None, inline=False, **kwargs): + """Return an [sqlalchemy.sql.expression#Insert] clause element. + + Similar functionality is available via the ``insert()`` method on + [sqlalchemy.schema#Table]. + + table + The table to be inserted into. + + values + A dictionary which specifies the column specifications of the + ``INSERT``, and is optional. If left as None, the column + specifications are determined from the bind parameters used + during the compile phase of the ``INSERT`` statement. If the + bind parameters also are None during the compile phase, then the + column specifications will be generated from the full list of + table columns. + + prefixes + A list of modifier keywords to be inserted between INSERT and INTO, + see ``Insert.prefix_with``. + + inline + if True, SQL defaults will be compiled 'inline' into the statement + and not pre-executed. + + If both `values` and compile-time bind parameters are present, the + compile-time bind parameters override the information specified + within `values` on a per-key basis. + + The keys within `values` can be either ``Column`` objects or their + string identifiers. Each key may reference one of: + + * a literal data value (i.e. string, number, etc.); + * a Column object; + * a SELECT statement. + + If a ``SELECT`` statement is specified which references this + ``INSERT`` statement's table, the statement will be correlated + against the ``INSERT`` statement. + """ + + return Insert(table, values, inline=inline, **kwargs) + +def update(table, whereclause=None, values=None, inline=False, **kwargs): + """Return an [sqlalchemy.sql.expression#Update] clause element. + + Similar functionality is available via the ``update()`` method on + [sqlalchemy.schema#Table]. + + table + The table to be updated. + + whereclause + A ``ClauseElement`` describing the ``WHERE`` condition of the + ``UPDATE`` statement. + + values + A dictionary which specifies the ``SET`` conditions of the + ``UPDATE``, and is optional. If left as None, the ``SET`` + conditions are determined from the bind parameters used during + the compile phase of the ``UPDATE`` statement. 
If the bind + parameters also are None during the compile phase, then the + ``SET`` conditions will be generated from the full list of table + columns. + + inline + if True, SQL defaults will be compiled 'inline' into the statement + and not pre-executed. + + + If both `values` and compile-time bind parameters are present, the + compile-time bind parameters override the information specified + within `values` on a per-key basis. + + The keys within `values` can be either ``Column`` objects or their + string identifiers. Each key may reference one of: + + * a literal data value (i.e. string, number, etc.); + * a Column object; + * a SELECT statement. + + If a ``SELECT`` statement is specified which references this + ``UPDATE`` statement's table, the statement will be correlated + against the ``UPDATE`` statement. + """ + + return Update(table, whereclause=whereclause, values=values, inline=inline, **kwargs) + +def delete(table, whereclause = None, **kwargs): + """Return a [sqlalchemy.sql.expression#Delete] clause element. + + Similar functionality is available via the ``delete()`` method on + [sqlalchemy.schema#Table]. + + table + The table to be updated. + + whereclause + A ``ClauseElement`` describing the ``WHERE`` condition of the + ``UPDATE`` statement. + """ + + return Delete(table, whereclause, **kwargs) + +def and_(*clauses): + """Join a list of clauses together using the ``AND`` operator. + + The ``&`` operator is also overloaded on all + [sqlalchemy.sql.expression#_CompareMixin] subclasses to produce the same + result. + """ + if len(clauses) == 1: + return clauses[0] + return ClauseList(operator=operators.and_, *clauses) + +def or_(*clauses): + """Join a list of clauses together using the ``OR`` operator. + + The ``|`` operator is also overloaded on all + [sqlalchemy.sql.expression#_CompareMixin] subclasses to produce the same + result. + """ + + if len(clauses) == 1: + return clauses[0] + return ClauseList(operator=operators.or_, *clauses) + +def not_(clause): + """Return a negation of the given clause, i.e. ``NOT(clause)``. + + The ``~`` operator is also overloaded on all + [sqlalchemy.sql.expression#_CompareMixin] subclasses to produce the same + result. + """ + + return operators.inv(_literal_as_binds(clause)) + +def distinct(expr): + """Return a ``DISTINCT`` clause.""" + + return _UnaryExpression(expr, operator=operators.distinct_op) + +def between(ctest, cleft, cright): + """Return a ``BETWEEN`` predicate clause. + + Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``. + + The ``between()`` method on all [sqlalchemy.sql.expression#_CompareMixin] subclasses + provides similar functionality. + """ + + ctest = _literal_as_binds(ctest) + return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operators.and_, group=False), operators.between_op) + + +def case(whens, value=None, else_=None): + """Produce a ``CASE`` statement. + + whens + A sequence of pairs, or alternatively a dict, + to be translated into "WHEN / THEN" clauses. + + value + Optional for simple case statements, produces + a column expression as in "CASE WHEN ..." + + else\_ + Optional as well, for case defaults produces + the "ELSE" portion of the "CASE" statement. + + The expressions used for THEN and ELSE, + when specified as strings, will be interpreted + as bound values. To specify textual SQL expressions + for these, use the text() construct. 
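# e.g. (illustrative, not part of the patch): a textual THEN expression given
# via text() rather than as a bound literal; the "orders" table is invented.
from sqlalchemy import MetaData, Table, Column, Integer, Numeric
from sqlalchemy.sql.expression import case, text

meta = MetaData()
orders = Table('orders', meta, Column('qty', Integer), Column('unit_price', Numeric(10, 2)))

expr = case([(orders.c.qty > 100, text("unit_price * 0.9"))], else_=orders.c.unit_price)
print expr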
+ + The expressions used for the WHEN criterion + may only be literal strings when "value" is + present, i.e. CASE table.somecol WHEN "x" THEN "y". + Otherwise, literal strings are not accepted + in this position, and either the text() + or literal() constructs must be used to + interpret raw string values. + + Usage examples:: + + case([(orderline.c.qty > 100, item.c.specialprice), + (orderline.c.qty > 10, item.c.bulkprice) + ], else_=item.c.regularprice) + case(value=emp.c.type, whens={ + 'engineer': emp.c.salary * 1.1, + 'manager': emp.c.salary * 3, + }) + """ + try: + whens = util.dictlike_iteritems(whens) + except TypeError: + pass + + if value: + crit_filter = _literal_as_binds + else: + crit_filter = _no_literals + + whenlist = [ClauseList('WHEN', crit_filter(c), 'THEN', _literal_as_binds(r), operator=None) + for (c,r) in whens] + if else_ is not None: + whenlist.append(ClauseList('ELSE', _literal_as_binds(else_), operator=None)) + if whenlist: + type = list(whenlist[-1])[-1].type + else: + type = None + cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END']) + return cc + +def cast(clause, totype, **kwargs): + """Return a ``CAST`` function. + + Equivalent of SQL ``CAST(clause AS totype)``. + + Use with a [sqlalchemy.types#TypeEngine] subclass, i.e:: + + cast(table.c.unit_price * table.c.qty, Numeric(10,4)) + + or:: + + cast(table.c.timestamp, DATE) + """ + + return _Cast(clause, totype, **kwargs) + +def extract(field, expr): + """Return the clause ``extract(field FROM expr)``.""" + + expr = _BinaryExpression(text(field), expr, operators.from_) + return func.extract(expr) + +def collate(expression, collation): + """Return the clause ``expression COLLATE collation``.""" + + expr = _literal_as_binds(expression) + return _CalculatedClause( + expr, expr, _literal_as_text(collation), + operator=operators.collate, group=False) + +def exists(*args, **kwargs): + """Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object. + + The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by + itself or used as a subquery within an enclosing select. + + \*args, \**kwargs + all arguments are sent directly to the [sqlalchemy.sql.expression#select()] + function to produce a ``SELECT`` statement. + """ + + return _Exists(*args, **kwargs) + +def union(*selects, **kwargs): + """Return a ``UNION`` of multiple selectables. + + The returned object is an instance of [sqlalchemy.sql.expression#CompoundSelect]. + + A similar ``union()`` method is available on all + [sqlalchemy.sql.expression#FromClause] subclasses. + + \*selects + a list of [sqlalchemy.sql.expression#Select] instances. + + \**kwargs + available keyword arguments are the same as those of + [sqlalchemy.sql.expression#select()]. + """ + + return _compound_select('UNION', *selects, **kwargs) + +def union_all(*selects, **kwargs): + """Return a ``UNION ALL`` of multiple selectables. + + The returned object is an instance of [sqlalchemy.sql.expression#CompoundSelect]. + + A similar ``union_all()`` method is available on all + [sqlalchemy.sql.expression#FromClause] subclasses. + + \*selects + a list of [sqlalchemy.sql.expression#Select] instances. + + \**kwargs + available keyword arguments are the same as those of + [sqlalchemy.sql.expression#select()]. + """ + + return _compound_select('UNION ALL', *selects, **kwargs) + +def except_(*selects, **kwargs): + """Return an ``EXCEPT`` of multiple selectables. 
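# Rough sketch (not from the patch): union() combines complete Select
# constructs; the tables here are invented.
from sqlalchemy import MetaData, Table, Column, String
from sqlalchemy.sql.expression import select, union

meta = MetaData()
users = Table('users', meta, Column('name', String(40)))
managers = Table('managers', meta, Column('name', String(40)))

u = union(select([users.c.name]), select([managers.c.name]))
print u
# the CompoundSelect is itself a FromClause, e.g. u.alias('all_names')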
+ + The returned object is an instance of [sqlalchemy.sql.expression#CompoundSelect]. + + \*selects + a list of [sqlalchemy.sql.expression#Select] instances. + + \**kwargs + available keyword arguments are the same as those of + [sqlalchemy.sql.expression#select()]. + """ + return _compound_select('EXCEPT', *selects, **kwargs) + +def except_all(*selects, **kwargs): + """Return an ``EXCEPT ALL`` of multiple selectables. + + The returned object is an instance of [sqlalchemy.sql.expression#CompoundSelect]. + + \*selects + a list of [sqlalchemy.sql.expression#Select] instances. + + \**kwargs + available keyword arguments are the same as those of + [sqlalchemy.sql.expression#select()]. + """ + return _compound_select('EXCEPT ALL', *selects, **kwargs) + +def intersect(*selects, **kwargs): + """Return an ``INTERSECT`` of multiple selectables. + + The returned object is an instance of [sqlalchemy.sql.expression#CompoundSelect]. + + \*selects + a list of [sqlalchemy.sql.expression#Select] instances. + + \**kwargs + available keyword arguments are the same as those of + [sqlalchemy.sql.expression#select()]. + """ + return _compound_select('INTERSECT', *selects, **kwargs) + +def intersect_all(*selects, **kwargs): + """Return an ``INTERSECT ALL`` of multiple selectables. + + The returned object is an instance of [sqlalchemy.sql.expression#CompoundSelect]. + + \*selects + a list of [sqlalchemy.sql.expression#Select] instances. + + \**kwargs + available keyword arguments are the same as those of + [sqlalchemy.sql.expression#select()]. + """ + return _compound_select('INTERSECT ALL', *selects, **kwargs) + +def alias(selectable, alias=None): + """Return an [sqlalchemy.sql.expression#Alias] object. + + An ``Alias`` represents any [sqlalchemy.sql.expression#FromClause] with + an alternate name assigned within SQL, typically using the ``AS`` + clause when generated, e.g. ``SELECT * FROM table AS aliasname``. + + Similar functionality is available via the ``alias()`` method + available on all ``FromClause`` subclasses. + + selectable + any ``FromClause`` subclass, such as a table, select + statement, etc.. + + alias + string name to be assigned as the alias. If ``None``, a + random name will be generated. + """ + + return Alias(selectable, alias=alias) + + +def literal(value, type_=None): + """Return a literal clause, bound to a bind parameter. + + Literal clauses are created automatically when non- + ``ClauseElement`` objects (such as strings, ints, dates, etc.) are + used in a comparison operation with a + [sqlalchemy.sql.expression#_CompareMixin] subclass, such as a ``Column`` + object. Use this function to force the generation of a literal + clause, which will be created as a + [sqlalchemy.sql.expression#_BindParamClause] with a bound value. + + value + the value to be bound. Can be any Python object supported by + the underlying DB-API, or is translatable via the given type + argument. + + type\_ + an optional [sqlalchemy.types#TypeEngine] which will provide + bind-parameter translation for this literal. + """ + + return _BindParamClause(None, value, type_=type_, unique=True) + +def label(name, obj): + """Return a [sqlalchemy.sql.expression#_Label] object for the given [sqlalchemy.sql.expression#ColumnElement]. + + A label changes the name of an element in the columns clause of a + ``SELECT`` statement, typically via the ``AS`` SQL keyword. + + This functionality is more conveniently available via the + ``label()`` method on ``ColumnElement``. + + name + label name + + obj + a ``ColumnElement``. 
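# e.g. (illustrative, not part of the patch): alias() and label() together;
# the table name is invented.
from sqlalchemy import MetaData, Table, Column, Integer, String
from sqlalchemy.sql.expression import select, alias

meta = MetaData()
users = Table('users', meta, Column('id', Integer, primary_key=True), Column('name', String(40)))

u = alias(users, 'u')
# renders roughly as: SELECT u.name AS username FROM users AS u
print select([u.c.name.label('username')])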
+ """ + + return _Label(name, obj) + +def column(text, type_=None): + """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. + + The object returned is an instance of [sqlalchemy.sql.expression#_ColumnClause], + which represents the "syntactical" portion of the schema-level + [sqlalchemy.schema#Column] object. + + text + the name of the column. Quoting rules will be applied to the + clause like any other column name. For textual column + constructs that are not to be quoted, use the + [sqlalchemy.sql.expression#literal_column()] function. + + type\_ + an optional [sqlalchemy.types#TypeEngine] object which will + provide result-set translation for this column. + + """ + + return _ColumnClause(text, type_=type_) + +def literal_column(text, type_=None): + """Return a textual column expression, as would be in the columns + clause of a ``SELECT`` statement. + + The object returned supports further expressions in the same way as any + other column object, including comparison, math and string operations. + The type\_ parameter is important to determine proper expression behavior + (such as, '+' means string concatenation or numerical addition based on + the type). + + text + the text of the expression; can be any SQL expression. Quoting rules + will not be applied. To specify a column-name expression which should + be subject to quoting rules, use the + [sqlalchemy.sql.expression#column()] function. + + type\_ + an optional [sqlalchemy.types#TypeEngine] object which will provide + result-set translation and additional expression semantics for this + column. If left as None the type will be NullType. + """ + + return _ColumnClause(text, type_=type_, is_literal=True) + +def table(name, *columns): + """Return a [sqlalchemy.sql.expression#Table] object. + + This is a primitive version of the [sqlalchemy.schema#Table] object, + which is a subclass of this object. + """ + + return TableClause(name, *columns) + +def bindparam(key, value=None, shortname=None, type_=None, unique=False): + """Create a bind parameter clause with the given key. + + value + a default value for this bind parameter. a bindparam with a + value is called a ``value-based bindparam``. + + type\_ + a sqlalchemy.types.TypeEngine object indicating the type of this + bind param, will invoke type-specific bind parameter processing + + shortname + deprecated. + + unique + if True, bind params sharing the same name will have their + underlying ``key`` modified to a uniquely generated name. + mostly useful with value-based bind params. + """ + + if isinstance(key, _ColumnClause): + return _BindParamClause(key.name, value, type_=key.type, unique=unique, shortname=shortname) + else: + return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname) + +def outparam(key, type_=None): + """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them. + + The ``outparam`` can be used like a regular function parameter. + The "output" value will be available from the + [sqlalchemy.engine#ResultProxy] object via its ``out_parameters`` + attribute, which returns a dictionary containing the values. + """ + + return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True) + +def text(text, bind=None, *args, **kwargs): + """Create literal text to be inserted into a query. 
+ + When constructing a query from a ``select()``, ``update()``, + ``insert()`` or ``delete()``, using plain strings for argument + values will usually result in text objects being created + automatically. Use this function when creating textual clauses + outside of other ``ClauseElement`` objects, or optionally wherever + plain text is to be used. + + text + the text of the SQL statement to be created. use ``:`` + to specify bind parameters; they will be compiled to their + engine-specific format. + + bind + an optional connection or engine to be used for this text query. + + autocommit=True + indicates this SELECT statement modifies the database, and + should be subject to autocommit behavior if no transaction + has been started. + + bindparams + a list of ``bindparam()`` instances which can be used to define + the types and/or initial values for the bind parameters within + the textual statement; the keynames of the bindparams must match + those within the text of the statement. The types will be used + for pre-processing on bind values. + + typemap + a dictionary mapping the names of columns represented in the + ``SELECT`` clause of the textual statement to type objects, + which will be used to perform post-processing on columns within + the result set (for textual statements that produce result + sets). + + """ + + return _TextClause(text, bind=bind, *args, **kwargs) + +def null(): + """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement.""" + + return _Null() + +class _FunctionGenerator(object): + """Generate ``_Function`` objects based on getattr calls.""" + + def __init__(self, **opts): + self.__names = [] + self.opts = opts + + def __getattr__(self, name): + # passthru __ attributes; fixes pydoc + if name.startswith('__'): + try: + return self.__dict__[name] + except KeyError: + raise AttributeError(name) + + elif name.endswith('_'): + name = name[0:-1] + f = _FunctionGenerator(**self.opts) + f.__names = list(self.__names) + [name] + return f + + def __call__(self, *c, **kwargs): + o = self.opts.copy() + o.update(kwargs) + if len(self.__names) == 1: + global functions + if functions is None: + from sqlalchemy.sql import functions + func = getattr(functions, self.__names[-1].lower(), None) + if func is not None: + return func(*c, **o) + + return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) + +# "func" global - i.e. func.count() +func = _FunctionGenerator() + +# "modifier" global - i.e. modifier.distinct +# TODO: use UnaryExpression for this instead ? +modifier = _FunctionGenerator(group=False) + +def _clone(element): + return element._clone() + +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' predecessors.""" + + return itertools.chain(*[x._cloned_set for x in elements]) + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the enties present within 'a'. 
+ + """ + all_overlap = util.Set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return a.intersection( + [ + elem for elem in a if all_overlap.intersection(elem._cloned_set) + ] + ) + +def _compound_select(keyword, *selects, **kwargs): + return CompoundSelect(keyword, *selects, **kwargs) + +def _is_literal(element): + return not isinstance(element, ClauseElement) + +def _literal_as_text(element): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): + return _TextClause(unicode(element)) + else: + return element + +def _literal_as_column(element): + if isinstance(element, Operators): + return element.clause_element() + elif _is_literal(element): + return literal_column(str(element)) + else: + return element + +def _literal_as_binds(element, name=None, type_=None): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): + if element is None: + return null() + else: + return _BindParamClause(name, element, type_=type_, unique=True) + else: + return element + +def _no_literals(element): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): + raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + else: + return element + +def _corresponding_column_or_error(fromclause, column, require_embedded=False): + c = fromclause.corresponding_column(column, require_embedded=require_embedded) + if not c: + raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) + return c + +def _selectable(element): + if hasattr(element, '__selectable__'): + return element.__selectable__() + elif isinstance(element, Selectable): + return element + else: + raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) + + +def is_column(col): + """True if ``col`` is an instance of ``ColumnElement``.""" + return isinstance(col, ColumnElement) + + +class _FigureVisitName(type): + def __init__(cls, clsname, bases, dict): + if not '__visit_name__' in cls.__dict__: + m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname) + x = m.group(1) + x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x) + cls.__visit_name__ = x.lower() + super(_FigureVisitName, cls).__init__(clsname, bases, dict) + +class ClauseElement(object): + """Base class for elements of a programmatically constructed SQL expression.""" + __metaclass__ = _FigureVisitName + + def _clone(self): + """Create a shallow copy of this ClauseElement. + + This method may be used by a generative API. Its also used as + part of the "deep" copy afforded by a traversal that combines + the _copy_internals() method. + """ + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + + # this is a marker that helps to "equate" clauses to each other + # when a Select returns its list of FROM clauses. the cloning + # process leaves around a lot of remnants of the previous clause + # typically in the form of column expressions still attached to the + # old table. 
+ c._is_clone_of = self + + return c + + def _cloned_set(self): + f = self + while f is not None: + yield f + f = getattr(f, '_is_clone_of', None) + _cloned_set = property(_cloned_set) + + def _get_from_objects(self, **modifiers): + """Return objects represented in this ``ClauseElement`` that + should be added to the ``FROM`` list of a query, when this + ``ClauseElement`` is placed in the column clause of a + ``Select`` statement. + """ + + raise NotImplementedError(repr(self)) + + def unique_params(self, *optionaldict, **kwargs): + """Return a copy with ``bindparam()`` elments replaced. + + Same functionality as ``params()``, except adds `unique=True` + to affected bind parameters so that multiple statements can be + used. + """ + + return self._params(True, optionaldict, kwargs) + + def params(self, *optionaldict, **kwargs): + """Return a copy with ``bindparam()`` elments replaced. + + Returns a copy of this ClauseElement with ``bindparam()`` + elements replaced with values taken from the given dictionary:: + + >>> clause = column('x') + bindparam('foo') + >>> print clause.compile().params + {'foo':None} + >>> print clause.params({'foo':7}).compile().params + {'foo':7} + """ + + return self._params(False, optionaldict, kwargs) + + def _params(self, unique, optionaldict, kwargs): + if len(optionaldict) == 1: + kwargs.update(optionaldict[0]) + elif len(optionaldict) > 1: + raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument") + + def visit_bindparam(bind): + if bind.key in kwargs: + bind.value = kwargs[bind.key] + if unique: + bind._convert_to_unique() + return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True) + + def compare(self, other): + """Compare this ClauseElement to the given ClauseElement. + + Subclasses should override the default behavior, which is a + straight identity comparison. + """ + + return self is other + + def _copy_internals(self, clone=_clone): + """Reassign internal elements to be clones of themselves. + + Called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy. + """ + + pass + + def get_children(self, **kwargs): + """Return immediate child elements of this ``ClauseElement``. + + This is used for visit traversal. + + \**kwargs may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + """ + return [] + + def self_group(self, against=None): + return self + + def supports_execution(self): + """Return True if this clause element represents a complete executable statement.""" + + return False + + def bind(self): + """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" + + try: + if self._bind is not None: + return self._bind + except AttributeError: + pass + for f in self._get_from_objects(): + if f is self: + continue + engine = f.bind + if engine is not None: + return engine + else: + return None + bind = property(bind) + + def execute(self, *multiparams, **params): + """Compile and execute this ``ClauseElement``.""" + + e = self.bind + if e is None: + label = getattr(self, 'description', self.__class__.__name__) + msg = ('This %s is not bound and does not support direct ' + 'execution. Supply this statement to a Connection or ' + 'Engine for execution. 
Or, assign a bind to the statement ' + 'or the Metadata of its underlying tables to enable ' + 'implicit execution via this method.' % label) + raise exceptions.UnboundExecutionError(msg) + return e.execute_clauseelement(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Compile and execute this ``ClauseElement``, returning the result's scalar representation.""" + + return self.execute(*multiparams, **params).scalar() + + def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False): + """Compile this SQL expression. + + Uses the given ``Compiler``, or the given ``AbstractDialect`` + or ``Engine`` to create a ``Compiler``. If no `compiler` + arguments are given, tries to use the underlying ``Engine`` + this ``ClauseElement`` is bound to to create a ``Compiler``, + if any. + + Finally, if there is no bound ``Engine``, uses an + ``DefaultDialect`` to create a default ``Compiler``. + + `parameters` is a dictionary representing the default bind + parameters to be used with the statement. If `parameters` is + a list, it is assumed to be a list of dictionaries and the + first dictionary in the list is used with which to compile + against. + + The bind parameters can in some cases determine the output of + the compilation, such as for ``UPDATE`` and ``INSERT`` + statements the bind parameters that are present determine the + ``SET`` and ``VALUES`` clause of those statements. + """ + + if compiler is None: + if dialect is not None: + compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) + elif bind is not None: + compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline) + elif self.bind is not None: + compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline) + + if compiler is None: + global DefaultDialect + if DefaultDialect is None: + from sqlalchemy.engine.default import DefaultDialect + dialect = DefaultDialect() + compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) + compiler.compile() + return compiler + + def __str__(self): + return unicode(self.compile()).encode('ascii', 'backslashreplace') + + def __and__(self, other): + return and_(self, other) + + def __or__(self, other): + return or_(self, other) + + def __invert__(self): + return self._negate() + + def _negate(self): + if hasattr(self, 'negation_clause'): + return self.negation_clause + else: + return _UnaryExpression(self.self_group(against=operators.inv), operator=operators.inv, negate=None) + + def __repr__(self): + friendly = getattr(self, 'description', None) + if friendly is None: + return object.__repr__(self) + else: + return '<%s.%s at 0x%x; %s>' % ( + self.__module__, self.__class__.__name__, id(self), friendly) + + +class Operators(object): + def __and__(self, other): + return self.operate(operators.and_, other) + + def __or__(self, other): + return self.operate(operators.or_, other) + + def __invert__(self): + return self.operate(operators.inv) + + def op(self, opstring): + def op(b): + return self.operate(operators.op, opstring, b) + return op + + def clause_element(self): + raise NotImplementedError() + + def operate(self, op, *other, **kwargs): + raise NotImplementedError() + + def reverse_operate(self, op, other, **kwargs): + raise NotImplementedError() + +class ColumnOperators(Operators): + """Defines comparison and math operations.""" + + timetuple = None + """Hack, allows datetime objects to be compared on the LHS.""" + + def 
__lt__(self, other): + return self.operate(operators.lt, other) + + def __le__(self, other): + return self.operate(operators.le, other) + + def __eq__(self, other): + return self.operate(operators.eq, other) + + def __ne__(self, other): + return self.operate(operators.ne, other) + + def __gt__(self, other): + return self.operate(operators.gt, other) + + def __ge__(self, other): + return self.operate(operators.ge, other) + + def concat(self, other): + return self.operate(operators.concat_op, other) + + def like(self, other, escape=None): + return self.operate(operators.like_op, other, escape=escape) + + def ilike(self, other, escape=None): + return self.operate(operators.ilike_op, other, escape=escape) + + def in_(self, *other): + return self.operate(operators.in_op, other) + + def startswith(self, other, **kwargs): + return self.operate(operators.startswith_op, other, **kwargs) + + def endswith(self, other, **kwargs): + return self.operate(operators.endswith_op, other, **kwargs) + + def contains(self, other, **kwargs): + return self.operate(operators.contains_op, other, **kwargs) + + def desc(self): + return self.operate(operators.desc_op) + + def asc(self): + return self.operate(operators.asc_op) + + def collate(self, collation): + return self.operate(operators.collate, collation) + + def __radd__(self, other): + return self.reverse_operate(operators.add, other) + + def __rsub__(self, other): + return self.reverse_operate(operators.sub, other) + + def __rmul__(self, other): + return self.reverse_operate(operators.mul, other) + + def __rdiv__(self, other): + return self.reverse_operate(operators.div, other) + + def between(self, cleft, cright): + return self.operate(operators.between_op, cleft, cright) + + def distinct(self): + return self.operate(operators.distinct_op) + + def __add__(self, other): + return self.operate(operators.add, other) + + def __sub__(self, other): + return self.operate(operators.sub, other) + + def __mul__(self, other): + return self.operate(operators.mul, other) + + def __div__(self, other): + return self.operate(operators.div, other) + + def __mod__(self, other): + return self.operate(operators.mod, other) + + def __truediv__(self, other): + return self.operate(operators.truediv, other) + +class _CompareMixin(ColumnOperators): + """Defines comparison and math operations for ``ClauseElement`` instances.""" + + def __compare(self, op, obj, negate=None, reverse=False, **kwargs): + if obj is None or isinstance(obj, _Null): + if op == operators.eq: + return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot) + elif op == operators.ne: + return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_) + else: + raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") + else: + obj = self._check_literal(obj) + + if reverse: + return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + else: + return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + + def __operate(self, op, obj, reverse=False): + obj = self._check_literal(obj) + + type_ = self._compare_type(obj) + + if reverse: + return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_) + else: + return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_) + + # a mapping of operators with the method they use, along with their 
negated + # operator for comparison operators + operators = { + operators.add : (__operate,), + operators.mul : (__operate,), + operators.sub : (__operate,), + operators.div : (__operate,), + operators.mod : (__operate,), + operators.truediv : (__operate,), + operators.lt : (__compare, operators.ge), + operators.le : (__compare, operators.gt), + operators.ne : (__compare, operators.eq), + operators.gt : (__compare, operators.le), + operators.ge : (__compare, operators.lt), + operators.eq : (__compare, operators.ne), + operators.like_op : (__compare, operators.notlike_op), + operators.ilike_op : (__compare, operators.notilike_op), + } + + def operate(self, op, *other, **kwargs): + o = _CompareMixin.operators[op] + return o[0](self, op, other[0], *o[1:], **kwargs) + + def reverse_operate(self, op, other, **kwargs): + o = _CompareMixin.operators[op] + return o[0](self, op, other, reverse=True, *o[1:], **kwargs) + + def in_(self, *other): + return self._in_impl(operators.in_op, operators.notin_op, *other) + + def _in_impl(self, op, negate_op, *other): + # Handle old style *args argument passing + if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)): + util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable') + seq_or_selectable = other + else: + seq_or_selectable = other[0] + + if isinstance(seq_or_selectable, Selectable): + return self.__compare( op, seq_or_selectable, negate=negate_op) + + # Handle non selectable arguments as sequences + args = [] + for o in seq_or_selectable: + if not _is_literal(o): + if not isinstance( o, _CompareMixin): + raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) + else: + o = self._bind_param(o) + args.append(o) + + if len(args) == 0: + # Special case handling for empty IN's + return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1'))) + + return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op) + + def startswith(self, other, escape=None): + """Produce the clause ``LIKE '%'``""" + + # use __radd__ to force string concat behavior + return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other)), escape=escape) + + def endswith(self, other, escape=None): + """Produce the clause ``LIKE '%'``""" + + return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String) + self._check_literal(other), escape=escape) + + def contains(self, other, escape=None): + """Produce the clause ``LIKE '%%'``""" + + return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String) + self._check_literal(other) + literal_column("'%'", type_=sqltypes.String), escape=escape) + + def label(self, name): + """Produce a column label, i.e. `` AS ``. + + if 'name' is None, an anonymous label name will be generated. + """ + return _Label(name, self, self.type) + + def desc(self): + """Produce a DESC clause, i.e. `` DESC``""" + + return desc(self) + + def asc(self): + """Produce a ASC clause, i.e. `` ASC``""" + + return asc(self) + + def distinct(self): + """Produce a DISTINCT clause, i.e. ``DISTINCT ``""" + return _UnaryExpression(self, operator=operators.distinct_op) + + def between(self, cleft, cright): + """Produce a BETWEEN clause, i.e. 
`` BETWEEN AND ``""" + + return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operators.and_, group=False), operators.between_op) + + def collate(self, collation): + """Produce a COLLATE clause, i.e. `` COLLATE utf8_bin``""" + name = getattr(self, 'name', None) + return _CalculatedClause( + None, self, _literal_as_text(collation), + operator=operators.collate, group=False) + + def op(self, operator): + """produce a generic operator function. + + e.g.:: + + somecolumn.op("*")(5) + + produces:: + + somecolumn * 5 + + operator + a string which will be output as the infix operator between + this ``ClauseElement`` and the expression passed to the + generated function. + """ + return lambda other: self.__operate(operator, other) + + def _bind_param(self, obj): + return _BindParamClause(None, obj, type_=self.type, unique=True) + + def _check_literal(self, other): + if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): + other.type = self.type + return other + elif isinstance(other, Operators): + return other.expression_element() + elif _is_literal(other): + return self._bind_param(other) + else: + return other + + def clause_element(self): + """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``.""" + return self + + def expression_element(self): + """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" + + return self + + def _compare_type(self, obj): + """Allow subclasses to override the type used in constructing + ``_BinaryExpression`` objects. + + Default return value is the type of the given object. + """ + + return obj.type + +class ColumnElement(ClauseElement, _CompareMixin): + """Represent an element that is usable within the "column clause" portion of a ``SELECT`` statement. + + This includes columns associated with tables, aliases, and + subqueries, expressions, function calls, SQL keywords such as + ``NULL``, literals, etc. ``ColumnElement`` is the ultimate base + class for all such elements. + + ``ColumnElement`` supports the ability to be a *proxy* element, + which indicates that the ``ColumnElement`` may be associated with + a ``Selectable`` which was derived from another ``Selectable``. + An example of a "derived" ``Selectable`` is an ``Alias`` of a + ``Table``. + + A ``ColumnElement``, by subclassing the ``_CompareMixin`` mixin + class, provides the ability to generate new ``ClauseElement`` + objects using Python expressions. See the ``_CompareMixin`` + docstring for more details. + """ + + primary_key = False + foreign_keys = [] + + def base_columns(self): + if hasattr(self, '_base_columns'): + return self._base_columns + self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) + return self._base_columns + base_columns = property(base_columns) + + def proxy_set(self): + if hasattr(self, '_proxy_set'): + return self._proxy_set + s = util.Set([self]) + if hasattr(self, 'proxies'): + for c in self.proxies: + s = s.union(c.proxy_set) + self._proxy_set = s + return s + proxy_set = property(proxy_set) + + def shares_lineage(self, othercolumn): + """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``. 
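        For example, a column proxied through an ``Alias`` shares lineage
        with the column it was derived from. A minimal sketch, assuming the
        lightweight ``table()`` and ``column()`` constructors from this
        module are importable from ``sqlalchemy.sql``::

            from sqlalchemy.sql import table, column

            users = table('users', column('id'), column('name'))
            u1 = users.alias('u1')

            assert u1.c.id.shares_lineage(users.c.id)
            assert not u1.c.id.shares_lineage(users.c.name)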
+ """ + return len(self.proxy_set.intersection(othercolumn.proxy_set)) > 0 + + def _make_proxy(self, selectable, name=None): + """Create a new ``ColumnElement`` representing this + ``ColumnElement`` as it appears in the select list of a + descending selectable. + + """ + + if name: + co = _ColumnClause(name, selectable, type_=getattr(self, 'type', None)) + else: + name = str(self) + co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None)) + + co.proxies = [self] + selectable.columns[name]= co + return co + + def anon_label(self): + """provides a constant 'anonymous label' for this ColumnElement. + + This is a label() expression which will be named at compile time. + The same label() is returned each time anon_label is called so + that expressions can reference anon_label multiple times, producing + the same label name at compile time. + + the compiler uses this function automatically at compile time + for expressions that are known to be 'unnamed' like binary + expressions and function calls. + """ + + if not hasattr(self, '_ColumnElement__anon_label'): + self.__anon_label = "{ANON %d %s}" % (id(self), getattr(self, 'name', 'anon')) + return self.__anon_label + anon_label = property(anon_label) + +class ColumnCollection(util.OrderedProperties): + """An ordered dictionary that stores a list of ColumnElement + instances. + + Overrides the ``__eq__()`` method to produce SQL clauses between + sets of correlated columns. + """ + + def __init__(self, *cols): + super(ColumnCollection, self).__init__() + [self.add(c) for c in cols] + + def __str__(self): + return repr([str(c) for c in self]) + + def replace(self, column): + """add the given column to this collection, removing unaliased versions of this column + as well as existing columns with the same key. + + e.g.:: + + t = Table('sometable', Column('col1', Integer)) + t.replace_unalised(Column('col1', Integer, key='columnone')) + + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. + + Used by schema.Column to override columns during table reflection. + """ + + if column.name in self and column.key != column.name: + other = self[column.name] + if other.name == other.key: + del self[other.name] + util.OrderedProperties.__setitem__(self, column.key, column) + + def add(self, column): + """Add a column to this collection. + + The key attribute of the column will be used as the hash key + for this dictionary. + """ + + self[column.key] = column + + def __setitem__(self, key, value): + if key in self: + # this warning is primarily to catch select() statements which have conflicting + # column names in their exported columns collection + existing = self[key] + if not existing.shares_lineage(value): + table = getattr(existing, 'table', None) and existing.table.description + util.warn(("Column %r on table %r being replaced by another " + "column with the same key. 
Consider use_labels " + "for select() statements.") % (key, table)) + util.OrderedProperties.__setitem__(self, key, value) + + def remove(self, column): + del self[column.key] + + def extend(self, iter): + for c in iter: + self.add(c) + + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c==local) + return and_(*l) + + def __contains__(self, other): + if not isinstance(other, basestring): + raise exceptions.ArgumentError("__contains__ requires a string argument") + return util.OrderedProperties.__contains__(self, other) + + def contains_column(self, col): + # have to use a Set here, because it will compare the identity + # of the column, not just using "==" for comparison which will always return a + # "True" value (i.e. a BinaryClause...) + return col in util.Set(self) + +class ColumnSet(util.OrderedSet): + def contains_column(self, col): + return col in self + + def extend(self, cols): + for col in cols: + self.add(col) + + def __add__(self, other): + return list(self) + list(other) + + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c==local) + return and_(*l) + +class Selectable(ClauseElement): + """mark a class as being selectable""" + +class FromClause(Selectable): + """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement.""" + + __visit_name__ = 'fromclause' + named_with_column=False + _hide_froms = [] + + def _get_from_objects(self, **modifiers): + return [] + + def default_order_by(self): + return [self.oid_column] + + def count(self, whereclause=None, **params): + """return a SELECT COUNT generated against this ``FromClause``.""" + + if self.primary_key: + col = list(self.primary_key)[0] + else: + col = list(self.columns)[0] + return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) + + def select(self, whereclause=None, **params): + """return a SELECT of this ``FromClause``.""" + + return select([self], whereclause, **params) + + def join(self, right, onclause=None, isouter=False): + """return a join of this ``FromClause`` against another ``FromClause``.""" + + return Join(self, right, onclause, isouter) + + def outerjoin(self, right, onclause=None): + """return an outer join of this ``FromClause`` against another ``FromClause``.""" + + return Join(self, right, onclause, True) + + def alias(self, name=None): + """return an alias of this ``FromClause`` against another ``FromClause``.""" + + return Alias(self, name) + + def is_derived_from(self, fromclause): + """Return True if this FromClause is 'derived' from the given FromClause. + + An example would be an Alias of a Table is derived from that Table. 
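        A brief sketch of that case (import path assumed to be
        ``sqlalchemy.sql``; table and column names are illustrative)::

            from sqlalchemy.sql import table, column

            users = table('users', column('id'))
            u1 = users.alias('u1')

            assert u1.is_derived_from(users)
            assert not users.is_derived_from(u1)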
+ """ + return fromclause in util.Set(self._cloned_set) + + def replace_selectable(self, old, alias): + """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" + + global ClauseAdapter + if ClauseAdapter is None: + from sqlalchemy.sql.util import ClauseAdapter + return ClauseAdapter(alias).traverse(self, clone=True) + + def correspond_on_equivalents(self, column, equivalents): + col = self.corresponding_column(column, require_embedded=True) + if col is None and col in equivalents: + for equiv in equivalents[col]: + nc = self.corresponding_column(equiv, require_embedded=True) + if nc: + return nc + return col + + def corresponding_column(self, column, require_embedded=False): + """Given a ``ColumnElement``, return the exported ``ColumnElement`` + object from this ``Selectable`` which corresponds to that + original ``Column`` via a common anscestor column. + + column + the target ``ColumnElement`` to be matched + + require_embedded + only return corresponding columns for the given + ``ColumnElement``, if the given ``ColumnElement`` is + actually present within a sub-element of this + ``FromClause``. Normally the column will match if it merely + shares a common anscestor with one of the exported columns + of this ``FromClause``. + """ + + # dont dig around if the column is locally present + if self.c.contains_column(column): + return column + + col, intersect = None, None + target_set = column.proxy_set + cols = self.c + if self.oid_column: + cols += [self.oid_column] + for c in cols: + i = c.proxy_set.intersection(target_set) + if i and \ + (not require_embedded or c.proxy_set.issuperset(target_set)) and \ + (intersect is None or len(i) > len(intersect)): + col, intersect = c, i + return col + + def description(self): + """a brief description of this FromClause. + + Used primarily for error message formatting. + """ + return getattr(self, 'name', self.__class__.__name__ + " object") + description = property(description) + + def _reset_exported(self): + # delete all the "generated" collections of columns for a + # newly cloned FromClause, so that they will be re-derived + # from the item. this is because FromClause subclasses, when + # cloned, need to reestablish new "proxied" columns that are + # linked to the new item + for attr in ('_columns', '_primary_key' '_foreign_keys', '_oid_column', '_embedded_columns', '_all_froms'): + if hasattr(self, attr): + delattr(self, attr) + + def _expr_attr_func(name): + get = util.attrgetter(name) + def attr(self): + try: + return get(self) + except AttributeError: + self._export_columns() + return get(self) + return property(attr) + + columns = c = _expr_attr_func('_columns') + primary_key = _expr_attr_func('_primary_key') + foreign_keys = _expr_attr_func('_foreign_keys') + oid_column = _expr_attr_func('_oid_column') + + def _export_columns(self): + """Initialize column collections.""" + + if hasattr(self, '_columns'): + return + self._columns = ColumnCollection() + self._primary_key = ColumnSet() + self._foreign_keys = util.Set() + self._oid_column = None + self._populate_column_collection() + + def _populate_column_collection(self): + pass + +class _BindParamClause(ClauseElement, _CompareMixin): + """Represent a bind parameter. + + Public constructor is the ``bindparam()`` function. + """ + + __visit_name__ = 'bindparam' + + def __init__(self, key, value, type_=None, unique=False, isoutparam=False, shortname=None): + """Construct a _BindParamClause. + + key + the key for this bind param. 
Will be used in the generated + SQL statement for dialects that use named parameters. This + value may be modified when part of a compilation operation, + if other ``_BindParamClause`` objects exist with the same + key, or if its length is too long and truncation is + required. + + value + Initial value for this bind param. This value may be + overridden by the dictionary of parameters sent to statement + compilation/execution. + + shortname + deprecated. + + type\_ + A ``TypeEngine`` object that will be used to pre-process the + value corresponding to this ``_BindParamClause`` at + execution time. + + unique + if True, the key name of this BindParamClause will be + modified if another ``_BindParamClause`` of the same name + already has been located within the containing + ``ClauseElement``. + + isoutparam + if True, the parameter should be treated like a stored procedure "OUT" + parameter. + """ + + if unique: + self.key = "{ANON %d %s}" % (id(self), key or 'param') + else: + self.key = key or "{ANON %d param}" % id(self) + self._orig_key = key or 'param' + self.unique = unique + self.value = value + self.isoutparam = isoutparam + self.shortname = shortname + + if type_ is None: + self.type = sqltypes.type_map.get(type(value), sqltypes.NullType)() + elif isinstance(type_, type): + self.type = type_() + else: + self.type = type_ + + def _clone(self): + c = ClauseElement._clone(self) + if self.unique: + c.key = "{ANON %d %s}" % (id(c), c._orig_key or 'param') + return c + + def _convert_to_unique(self): + if not self.unique: + self.unique=True + self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param') + + def _get_from_objects(self, **modifiers): + return [] + + def bind_processor(self, dialect): + return self.type.dialect_impl(dialect).bind_processor(dialect) + + def _compare_type(self, obj): + if not isinstance(self.type, sqltypes.NullType): + return self.type + else: + return obj.type + + def compare(self, other): + """Compare this ``_BindParamClause`` to the given clause. + + Since ``compare()`` is meant to compare statement syntax, this + method returns True if the two ``_BindParamClauses`` have just + the same type. + """ + + return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ + + def __repr__(self): + return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type)) + +class _TypeClause(ClauseElement): + """Handle a type keyword in a SQL statement. + + Used by the ``Case`` statement. + """ + + __visit_name__ = 'typeclause' + + def __init__(self, type): + self.type = type + + def _get_from_objects(self, **modifiers): + return [] + +class _TextClause(ClauseElement): + """Represent a literal SQL text fragment. + + Public constructor is the ``text()`` function. + """ + + __visit_name__ = 'textclause' + + _bind_params_regex = re.compile(r'(? 
RIGHT``.""" + + def __init__(self, left, right, operator, type_=None, negate=None, modifiers=None): + ColumnElement.__init__(self) + self.left = _literal_as_text(left).self_group(against=operator) + self.right = _literal_as_text(right).self_group(against=operator) + self.operator = operator + self.type = sqltypes.to_instance(type_) + self.negate = negate + if modifiers is None: + self.modifiers = {} + else: + self.modifiers = modifiers + + def _get_from_objects(self, **modifiers): + return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) + + def _copy_internals(self, clone=_clone): + self.left = clone(self.left) + self.right = clone(self.right) + + def get_children(self, **kwargs): + return self.left, self.right + + def compare(self, other): + """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" + + return ( + isinstance(other, _BinaryExpression) and + self.operator == other.operator and + ( + self.left.compare(other.left) and + self.right.compare(other.right) or + ( + operators.is_commutative(self.operator) and + self.left.compare(other.right) and + self.right.compare(other.left) + ) + ) + ) + + def self_group(self, against=None): + # use small/large defaults for comparison so that unknown + # operators are always parenthesized + if self.operator != against and operators.is_precedent(self.operator, against): + return _Grouping(self) + else: + return self + + def _negate(self): + if self.negate is not None: + return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type, modifiers=self.modifiers) + else: + return super(_BinaryExpression, self)._negate() + +class _Exists(_UnaryExpression): + __visit_name__ = _UnaryExpression.__visit_name__ + + def __init__(self, *args, **kwargs): + kwargs['correlate'] = True + s = select(*args, **kwargs).as_scalar().self_group() + _UnaryExpression.__init__(self, s, operator=operators.exists) + + def select(self, whereclause=None, **params): + return select([self], whereclause, **params) + + def correlate(self, fromclause): + e = self._clone() + e.element = self.element.correlate(fromclause).self_group() + return e + + def where(self, clause): + """return a new exists() construct with the given expression added to its WHERE clause, joined + to the existing clause via AND, if any.""" + + e = self._clone() + e.element = self.element.where(clause).self_group() + return e + +class Join(FromClause): + """represent a ``JOIN`` construct between two ``FromClause`` elements. + + The public constructor function for ``Join`` is the module-level + ``join()`` function, as well as the ``join()`` method available + off all ``FromClause`` subclasses. 
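    A usage sketch, assuming ``users`` and ``addresses`` constructs with
    the illustrative column names shown (an explicit ``onclause`` is
    given, since the lightweight ``table()`` construct carries no foreign
    key information)::

        from sqlalchemy.sql import table, column, join

        users = table('users', column('id'), column('name'))
        addresses = table('addresses', column('id'), column('user_id'))

        j = join(users, addresses, users.c.id == addresses.c.user_id)
        stmt = j.select(use_labels=True)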
+ """ + + def __init__(self, left, right, onclause=None, isouter=False): + self.left = _selectable(left) + self.right = _selectable(right).self_group() + + if onclause is None: + self.onclause = self.__match_primaries(self.left, self.right) + else: + self.onclause = onclause + + self.isouter = isouter + self.__folded_equivalents = None + + def description(self): + return "Join object on %s(%d) and %s(%d)" % (self.left.description, id(self.left), self.right.description, id(self.right)) + description = property(description) + + def is_derived_from(self, fromclause): + return fromclause is self or self.left.is_derived_from(fromclause) or self.right.is_derived_from(fromclause) + + def self_group(self, against=None): + return _FromGrouping(self) + + def _populate_column_collection(self): + columns = [c for c in self.left.columns] + [c for c in self.right.columns] + + global sql_util + if not sql_util: + from sqlalchemy.sql import util as sql_util + self._primary_key.extend(sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause)) + self._columns.update([(col._label, col) for col in columns]) + self._foreign_keys.update(itertools.chain(*[col.foreign_keys for col in columns])) + self._oid_column = self.left.oid_column + + def _copy_internals(self, clone=_clone): + self._reset_exported() + self.left = clone(self.left) + self.right = clone(self.right) + self.onclause = clone(self.onclause) + self.__folded_equivalents = None + + def get_children(self, **kwargs): + return self.left, self.right, self.onclause + + def __match_primaries(self, primary, secondary): + global sql_util + if not sql_util: + from sqlalchemy.sql import util as sql_util + return sql_util.join_condition(primary, secondary) + + def select(self, whereclause=None, fold_equivalents=False, **kwargs): + """Create a ``Select`` from this ``Join``. + + whereclause + the WHERE criterion that will be sent to the ``select()`` + function + + fold_equivalents + based on the join criterion of this ``Join``, do not include + repeat column names in the column list of the resulting + select, for columns that are calculated to be "equivalent" + based on the join criterion of this ``Join``. This will + recursively apply to any joins directly nested by this one + as well. + + \**kwargs + all other kwargs are sent to the underlying ``select()`` function. + See the ``select()`` module level function for details. + """ + + if fold_equivalents: + global sql_util + if not sql_util: + from sqlalchemy.sql import util as sql_util + collist = sql_util.folded_equivalents(self) + else: + collist = [self.left, self.right] + + return select(collist, whereclause, from_obj=[self], **kwargs) + + def bind(self): + return self.left.bind or self.right.bind + bind = property(bind) + + def alias(self, name=None): + """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it. + + The ``Select`` is not correlating. + """ + + return self.select(use_labels=True, correlate=False).alias(name) + + def _hide_froms(self): + return itertools.chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set]) + _hide_froms = property(_hide_froms) + + def _get_from_objects(self, **modifiers): + return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) + +class Alias(FromClause): + """Represents an table or selectable alias (AS). 
+ + Represents an alias, as typically applied to any table or + sub-select within a SQL statement using the ``AS`` keyword (or + without the keyword on certain databases such as Oracle). + + This object is constructed from the ``alias()`` module level + function as well as the ``alias()`` method available on all + ``FromClause`` subclasses. + """ + + named_with_column = True + + def __init__(self, selectable, alias=None): + baseselectable = selectable + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable + self.selectable = selectable + if alias is None: + if self.original.named_with_column: + alias = getattr(self.original, 'name', None) + alias = '{ANON %d %s}' % (id(self), alias or 'anon') + self.name = alias + + def description(self): + return self.name.encode('ascii', 'backslashreplace') + description = property(description) + + def is_derived_from(self, fromclause): + if fromclause in util.Set(self._cloned_set): + return True + return self.selectable.is_derived_from(fromclause) + + def supports_execution(self): + return self.original.supports_execution() + + def _table_iterator(self): + return self.original._table_iterator() + + def _populate_column_collection(self): + for col in self.selectable.columns: + col._make_proxy(self) + if self.selectable.oid_column is not None: + self._oid_column = self.selectable.oid_column._make_proxy(self) + + def _copy_internals(self, clone=_clone): + self._reset_exported() + self.selectable = _clone(self.selectable) + baseselectable = self.selectable + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable + + def get_children(self, column_collections=True, aliased_selectables=True, **kwargs): + if column_collections: + for c in self.c: + yield c + if aliased_selectables: + yield self.selectable + + def _get_from_objects(self, **modifiers): + return [self] + + def bind(self): + return self.selectable.bind + bind = property(bind) + +class _ColumnElementAdapter(ColumnElement): + """Adapts a ClauseElement which may or may not be a + ColumnElement subclass itself into an object which + acts like a ColumnElement. 
+ """ + + def __init__(self, elem): + self.elem = elem + self.type = getattr(elem, 'type', None) + + def key(self): + return self.elem.key + key = property(key) + + def _label(self): + try: + return self.elem._label + except AttributeError: + return self.anon_label + _label = property(_label) + + def _copy_internals(self, clone=_clone): + self.elem = clone(self.elem) + + def get_children(self, **kwargs): + return self.elem, + + def _get_from_objects(self, **modifiers): + return self.elem._get_from_objects(**modifiers) + + def __getattr__(self, attr): + return getattr(self.elem, attr) + + def __getstate__(self): + return {'elem':self.elem, 'type':self.type} + + def __setstate__(self, state): + self.elem = state['elem'] + self.type = state['type'] + +class _Grouping(_ColumnElementAdapter): + """Represent a grouping within a column expression""" + pass + +class _FromGrouping(FromClause): + """Represent a grouping of a FROM clause""" + __visit_name__ = 'grouping' + + def __init__(self, elem): + self.elem = elem + + def columns(self): + return self.elem.columns + columns = c = property(columns) + + def _hide_froms(self): + return self.elem._hide_froms + _hide_froms = property(_hide_froms) + + def get_children(self, **kwargs): + return self.elem, + + def _copy_internals(self, clone=_clone): + self.elem = clone(self.elem) + + def _get_from_objects(self, **modifiers): + return self.elem._get_from_objects(**modifiers) + + def __getattr__(self, attr): + return getattr(self.elem, attr) + +class _Label(ColumnElement): + """Represents a column label (AS). + + Represent a label, as typically applied to any column-level + element using the ``AS`` sql keyword. + + This object is constructed from the ``label()`` module level + function as well as the ``label()`` method available on all + ``ColumnElement`` subclasses. + """ + + def __init__(self, name, obj, type_=None): + while isinstance(obj, _Label): + obj = obj.obj + self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) + self.obj = obj.self_group(against=operators.as_) + self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) + + def key(self): + return self.name + key = property(key) + + def _label(self): + return self.name + _label = property(_label) + + def _proxy_attr(name): + def attr(self): + return getattr(self.obj, name) + return property(attr) + + proxies = _proxy_attr('proxies') + base_columns = _proxy_attr('base_columns') + proxy_set = _proxy_attr('proxy_set') + primary_key = _proxy_attr('primary_key') + foreign_keys = _proxy_attr('foreign_keys') + + def expression_element(self): + return self.obj + + def get_children(self, **kwargs): + return self.obj, + + def _copy_internals(self, clone=_clone): + self.obj = clone(self.obj) + + def _get_from_objects(self, **modifiers): + return self.obj._get_from_objects(**modifiers) + + def _make_proxy(self, selectable, name = None): + if isinstance(self.obj, (Selectable, ColumnElement)): + e = self.obj._make_proxy(selectable, name=self.name) + else: + e = column(self.name)._make_proxy(selectable=selectable) + e.proxies.append(self) + return e + +class _ColumnClause(ColumnElement): + """Represents a generic column expression from any textual string. + + This includes columns associated with tables, aliases and select + statements, but also any arbitrary text. May or may not be bound + to an underlying ``Selectable``. ``_ColumnClause`` is usually + created publically via the ``column()`` function or the + ``literal_column()`` function. 
+ + text + the text of the element. + + selectable + parent selectable. + + type + ``TypeEngine`` object which can associate this ``_ColumnClause`` + with a type. + + is_literal + if True, the ``_ColumnClause`` is assumed to be an exact + expression that will be delivered to the output with no quoting + rules applied regardless of case sensitive settings. the + ``literal_column()`` function is usually used to create such a + ``_ColumnClause``. + """ + + def __init__(self, text, selectable=None, type_=None, _is_oid=False, is_literal=False): + ColumnElement.__init__(self) + self.key = self.name = text + self.table = selectable + self.type = sqltypes.to_instance(type_) + self._is_oid = _is_oid + self.__label = None + self.is_literal = is_literal + + def description(self): + return self.name.encode('ascii', 'backslashreplace') + description = property(description) + + def _clone(self): + # ColumnClause is immutable + return self + + def _label(self): + """Generate a 'label' string for this column. + """ + + # for a "literal" column, we've no idea what the text is + # therefore no 'label' can be automatically generated + if self.is_literal: + return None + if not self.__label: + if self.table and self.table.named_with_column: + if getattr(self.table, 'schema', None): + self.__label = self.table.schema + "_" + self.table.name + "_" + self.name + else: + self.__label = self.table.name + "_" + self.name + + if self.__label in self.table.c: + label = self.__label + counter = 1 + while label in self.table.c: + label = self.__label + "_" + str(counter) + counter +=1 + self.__label = label + else: + self.__label = self.name + return self.__label + + _label = property(_label) + + def label(self, name): + # if going off the "__label" property and its None, we have + # no label; return self + if name is None: + return self + else: + return super(_ColumnClause, self).label(name) + + def _get_from_objects(self, **modifiers): + if self.table is not None: + return [self.table] + else: + return [] + + def _bind_param(self, obj): + return _BindParamClause(self.name, obj, type_=self.type, unique=True) + + def _make_proxy(self, selectable, name = None): + # propigate the "is_literal" flag only if we are keeping our name, + # otherwise its considered to be a label + is_literal = self.is_literal and (name is None or name == self.name) + c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) + c.proxies = [self] + if not self._is_oid: + selectable.columns[c.name] = c + return c + + def _compare_type(self, obj): + return self.type + +class TableClause(FromClause): + """Represents a "table" construct. + + Note that this represents tables only as another syntactical + construct within SQL expressions; it does not provide schema-level + functionality. 
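    A short sketch of the construct in use (names are illustrative)::

        from sqlalchemy.sql import table, column

        users = table('users', column('id'), column('name'))

        stmt = users.select(users.c.name == 'ed')
        cnt = users.count()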
+ """ + + named_with_column = True + + def __init__(self, name, *columns): + super(TableClause, self).__init__() + self.name = self.fullname = name + self._oid_column = _ColumnClause('oid', self, _is_oid=True) + self._columns = ColumnCollection() + self._primary_key = ColumnSet() + self._foreign_keys = util.Set() + for c in columns: + self.append_column(c) + + def _export_columns(self): + raise NotImplementedError() + + def description(self): + return self.name.encode('ascii', 'backslashreplace') + description = property(description) + + def _clone(self): + # TableClause is immutable + return self + + def append_column(self, c): + self._columns[c.name] = c + c.table = self + + def get_children(self, column_collections=True, **kwargs): + if column_collections: + return [c for c in self.c] + else: + return [] + + def count(self, whereclause=None, **params): + if self.primary_key: + col = list(self.primary_key)[0] + else: + col = list(self.columns)[0] + return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) + + def insert(self, values=None, inline=False, **kwargs): + return insert(self, values=values, inline=inline, **kwargs) + + def update(self, whereclause=None, values=None, inline=False, **kwargs): + return update(self, whereclause=whereclause, values=values, inline=inline, **kwargs) + + def delete(self, whereclause=None): + return delete(self, whereclause) + + def _get_from_objects(self, **modifiers): + return [self] + + +class _SelectBaseMixin(object): + """Base class for ``Select`` and ``CompoundSelects``.""" + + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, autocommit=False): + self.use_labels = use_labels + self.for_update = for_update + self._autocommit = autocommit + self._limit = limit + self._offset = offset + self._bind = bind + + self._order_by_clause = ClauseList(*util.to_list(order_by, [])) + self._group_by_clause = ClauseList(*util.to_list(group_by, [])) + + def as_scalar(self): + """return a 'scalar' representation of this selectable, which can be used + as a column expression. + + Typically, a select statement which has only one column in its columns clause + is eligible to be used as a scalar expression. + + The returned object is an instance of [sqlalchemy.sql.expression#_ScalarSelect]. + + """ + return _ScalarSelect(self) + + def apply_labels(self): + """return a new selectable with the 'use_labels' flag set to True. + + This will result in column expressions being generated using labels against their table + name, such as "SELECT somecolumn AS tablename_somecolumn". This allows selectables which + contain multiple FROM clauses to produce a unique set of column names regardless of name conflicts + among the individual FROM clauses. + + """ + s = self._generate() + s.use_labels = True + return s + + def label(self, name): + """return a 'scalar' representation of this selectable, embedded as a subquery + with a label. + + See also ``as_scalar()``. 
+ + """ + return self.as_scalar().label(name) + + def supports_execution(self): + """part of the ClauseElement contract; returns ``True`` in all cases for this class.""" + + return True + + def autocommit(self): + """return a new selectable with the 'autocommit' flag set to True.""" + + s = self._generate() + s._autocommit = True + return s + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + s._reset_exported() + return s + + def limit(self, limit): + """return a new selectable with the given LIMIT criterion applied.""" + + s = self._generate() + s._limit = limit + return s + + def offset(self, offset): + """return a new selectable with the given OFFSET criterion applied.""" + + s = self._generate() + s._offset = offset + return s + + def order_by(self, *clauses): + """return a new selectable with the given list of ORDER BY criterion applied. + + The criterion will be appended to any pre-existing ORDER BY criterion. + + """ + s = self._generate() + s.append_order_by(*clauses) + return s + + def group_by(self, *clauses): + """return a new selectable with the given list of GROUP BY criterion applied. + + The criterion will be appended to any pre-existing GROUP BY criterion. + + """ + s = self._generate() + s.append_group_by(*clauses) + return s + + def append_order_by(self, *clauses): + """Append the given ORDER BY criterion applied to this selectable. + + The criterion will be appended to any pre-existing ORDER BY criterion. + + """ + if len(clauses) == 1 and clauses[0] is None: + self._order_by_clause = ClauseList() + else: + if getattr(self, '_order_by_clause', None): + clauses = list(self._order_by_clause) + list(clauses) + self._order_by_clause = ClauseList(*clauses) + + def append_group_by(self, *clauses): + """Append the given GROUP BY criterion applied to this selectable. + + The criterion will be appended to any pre-existing GROUP BY criterion. + + """ + if len(clauses) == 1 and clauses[0] is None: + self._group_by_clause = ClauseList() + else: + if getattr(self, '_group_by_clause', None): + clauses = list(self._group_by_clause) + list(clauses) + self._group_by_clause = ClauseList(*clauses) + + def _get_from_objects(self, is_where=False, **modifiers): + if is_where: + return [] + else: + return [self] + +class _ScalarSelect(_Grouping): + __visit_name__ = 'grouping' + + def __init__(self, elem): + self.elem = elem + cols = list(elem.inner_columns) + if len(cols) != 1: + raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") + self.type = cols[0].type + + def columns(self): + raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") + columns = c = property(columns) + + def self_group(self, **kwargs): + return self + + def _make_proxy(self, selectable, name): + return list(self.inner_columns)[0]._make_proxy(selectable, name) + + def _get_from_objects(self, **modifiers): + return [] + +class CompoundSelect(_SelectBaseMixin, FromClause): + def __init__(self, keyword, *selects, **kwargs): + self._should_correlate = kwargs.pop('correlate', False) + self.keyword = keyword + self.selects = [] + + numcols = None + + # some DBs do not like ORDER BY in the inner queries of a UNION, etc. 
+ for n, s in enumerate(selects): + if not numcols: + numcols = len(s.c) + elif len(s.c) != numcols: + raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % + (1, len(self.selects[0].c), n+1, len(s.c)) + ) + if s._order_by_clause: + s = s.order_by(None) + # unions group from left to right, so don't group first select + if n: + self.selects.append(s.self_group(self)) + else: + self.selects.append(s) + + _SelectBaseMixin.__init__(self, **kwargs) + + def self_group(self, against=None): + return _FromGrouping(self) + + def _populate_column_collection(self): + for cols in zip(*[s.c for s in self.selects]): + proxy = cols[0]._make_proxy(self, name=self.use_labels and cols[0]._label or None) + proxy.proxies = cols + + oid_proxies = [ + c for c in [f.oid_column for f in self.selects] if c is not None + ] + + if oid_proxies: + col = oid_proxies[0]._make_proxy(self) + col.proxies = oid_proxies + self._oid_column = col + + def _copy_internals(self, clone=_clone): + self._reset_exported() + self.selects = [clone(s) for s in self.selects] + if hasattr(self, '_col_map'): + del self._col_map + for attr in ('_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, clone(getattr(self, attr))) + + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.c) or []) + \ + [self._order_by_clause, self._group_by_clause] + list(self.selects) + + def _table_iterator(self): + for s in self.selects: + for t in s._table_iterator(): + yield t + + def bind(self): + if self._bind: + return self._bind + for s in self.selects: + e = s.bind + if e: + return e + else: + return None + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + +class Select(_SelectBaseMixin, FromClause): + """Represents a ``SELECT`` statement. + + Select statements support appendable clauses, as well as the + ability to execute themselves and return a result set. + """ + + def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs): + """Construct a Select object. + + The public constructor for Select is the + [sqlalchemy.sql.expression#select()] function; see that function for + argument descriptions. + + Additional generative and mutator methods are available on the + [sqlalchemy.sql.expression#_SelectBaseMixin] superclass. + """ + + self._should_correlate = correlate + self._distinct = distinct + + self._correlate = util.Set() + + if columns: + self._raw_columns = [ + isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c + for c in + [_literal_as_column(c) for c in columns] + ] + else: + self._raw_columns = [] + + if from_obj: + self._froms = util.Set([ + _is_literal(f) and _TextClause(f) or f + for f in util.to_list(from_obj) + ]) + else: + self._froms = util.Set() + + if whereclause: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None + + if having: + self._having = _literal_as_text(having) + else: + self._having = None + + if prefixes: + self._prefixes = [_literal_as_text(p) for p in prefixes] + else: + self._prefixes = [] + + _SelectBaseMixin.__init__(self, **kwargs) + + def _get_display_froms(self, existing_froms=None): + """Return the full list of 'from' clauses to be displayed. 
+ + Takes into account a set of existing froms which may be + rendered in the FROM clause of enclosing selects; this Select + may want to leave those absent if it is automatically + correlating. + + """ + froms = util.OrderedSet() + + for col in self._raw_columns: + froms.update(col._get_from_objects()) + + if self._whereclause is not None: + froms.update(self._whereclause._get_from_objects(is_where=True)) + + if self._froms: + froms.update(self._froms) + + toremove = itertools.chain(*[f._hide_froms for f in froms]) + froms.difference_update(toremove) + + if len(froms) > 1 or self._correlate: + if self._correlate: + froms.difference_update(_cloned_intersection(froms, self._correlate)) + + if self._should_correlate and existing_froms: + froms.difference_update(_cloned_intersection(froms, existing_froms)) + + if not len(froms): + raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate() to control correlation manually." % self) + + return froms + + froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") + + def type(self): + raise exceptions.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.") + type = property(type) + + def locate_all_froms(self): + """return a Set of all FromClause elements referenced by this Select. + + This set is a superset of that returned by the ``froms`` property, which + is specifically for those FromClause elements that would actually be rendered. + + """ + if hasattr(self, '_all_froms'): + return self._all_froms + + froms = util.Set( + itertools.chain(* + [self._froms] + + [f._get_from_objects() for f in self._froms] + + [col._get_from_objects() for col in self._raw_columns] + ) + ) + + if self._whereclause: + froms.update(self._whereclause._get_from_objects(is_where=True)) + + self._all_froms = froms + return froms + + def inner_columns(self): + """an iteratorof all ColumnElement expressions which would + be rendered into the columns clause of the resulting SELECT statement. 
+ + """ + for c in self._raw_columns: + if isinstance(c, Selectable): + for co in c.columns: + yield co + else: + yield c + inner_columns = property(inner_columns) + + def is_derived_from(self, fromclause): + if self in util.Set(fromclause._cloned_set): + return True + + for f in self.locate_all_froms(): + if f.is_derived_from(fromclause): + return True + return False + + def _copy_internals(self, clone=_clone): + self._reset_exported() + from_cloned = dict([(f, clone(f)) for f in self._froms.union(self._correlate)]) + self._froms = util.Set([from_cloned[f] for f in self._froms]) + self._correlate = util.Set([from_cloned[f] for f in self._correlate]) + self._raw_columns = [clone(c) for c in self._raw_columns] + for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, clone(getattr(self, attr))) + + def get_children(self, column_collections=True, **kwargs): + """return child elements as per the ClauseElement specification.""" + + return (column_collections and list(self.columns) or []) + \ + list(self.locate_all_froms()) + \ + [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] + + def column(self, column): + """return a new select() construct with the given column expression added to its columns clause.""" + + s = self._generate() + column = _literal_as_column(column) + + if isinstance(column, _ScalarSelect): + column = column.self_group(against=operators.comma_op) + + s._raw_columns = s._raw_columns + [column] + return s + + def where(self, whereclause): + """return a new select() construct with the given expression added to its WHERE clause, joined + to the existing clause via AND, if any.""" + + s = self._generate() + s.append_whereclause(whereclause) + return s + + def having(self, having): + """return a new select() construct with the given expression added to its HAVING clause, joined + to the existing clause via AND, if any.""" + + s = self._generate() + s.append_having(having) + return s + + def distinct(self): + """return a new select() construct which will apply DISTINCT to its columns clause.""" + + s = self._generate() + s._distinct = True + return s + + def prefix_with(self, clause): + """return a new select() construct which will apply the given expression to the start of its + columns clause, not using any commas.""" + + s = self._generate() + clause = _literal_as_text(clause) + s._prefixes = s._prefixes + [clause] + return s + + def select_from(self, fromclause): + """return a new select() construct with the given FROM expression applied to its list of + FROM objects.""" + + s = self._generate() + if _is_literal(fromclause): + fromclause = _TextClause(fromclause) + + s._froms = s._froms.union([fromclause]) + return s + + def correlate(self, *fromclauses): + """return a new select() construct which will correlate the given FROM clauses to that + of an enclosing select(), if a match is found. + + By "match", the given fromclause must be present in this select's list of FROM objects + and also present in an enclosing select's list of FROM objects. + + Calling this method turns off the select's default behavior of "auto-correlation". Normally, + select() auto-correlates all of its FROM clauses to those of an embedded select when + compiled. + + If the fromclause is None, correlation is disabled for the returned select(). 
+ + """ + s = self._generate() + s._should_correlate=False + if fromclauses == (None,): + s._correlate = util.Set() + else: + s._correlate = s._correlate.union(fromclauses) + return s + + def append_correlation(self, fromclause): + """append the given correlation expression to this select() construct.""" + + self._should_correlate=False + self._correlate = self._correlate.union([fromclause]) + + def append_column(self, column): + """append the given column expression to the columns clause of this select() construct.""" + + column = _literal_as_column(column) + + if isinstance(column, _ScalarSelect): + column = column.self_group(against=operators.comma_op) + + self._raw_columns = self._raw_columns + [column] + self._reset_exported() + + def append_prefix(self, clause): + """append the given columns clause prefix expression to this select() construct.""" + + clause = _literal_as_text(clause) + self._prefixes = self._prefixes.union([clause]) + + def append_whereclause(self, whereclause): + """append the given expression to this select() construct's WHERE criterion. + + The expression will be joined to existing WHERE criterion via AND. + + """ + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) + else: + self._whereclause = _literal_as_text(whereclause) + + def append_having(self, having): + """append the given expression to this select() construct's HAVING criterion. + + The expression will be joined to existing HAVING criterion via AND. + + """ + if self._having is not None: + self._having = and_(self._having, _literal_as_text(having)) + else: + self._having = _literal_as_text(having) + + def append_from(self, fromclause): + """append the given FromClause expression to this select() construct's FROM clause. + + """ + if _is_literal(fromclause): + fromclause = _TextClause(fromclause) + + self._froms = self._froms.union([fromclause]) + + def __exportable_columns(self): + for column in self._raw_columns: + if isinstance(column, Selectable): + for co in column.columns: + yield co + elif isinstance(column, ColumnElement): + yield column + else: + continue + + def _populate_column_collection(self): + for c in self.__exportable_columns(): + c._make_proxy(self, name=self.use_labels and c._label or None) + + oid_proxies = [c for c in + [f.oid_column for f in self.locate_all_froms() + if f is not self] if c is not None + ] + + if oid_proxies: + col = oid_proxies[0]._make_proxy(self) + col.proxies = oid_proxies + self._oid_column = col + + def self_group(self, against=None): + """return a 'grouping' construct as per the ClauseElement specification. + + This produces an element that can be embedded in an expression. Note that + this method is called automatically as needed when constructing expressions. 
+ + """ + if isinstance(against, CompoundSelect): + return self + return _FromGrouping(self) + + def union(self, other, **kwargs): + """return a SQL UNION of this select() construct against the given selectable.""" + + return union(self, other, **kwargs) + + def union_all(self, other, **kwargs): + """return a SQL UNION ALL of this select() construct against the given selectable.""" + + return union_all(self, other, **kwargs) + + def except_(self, other, **kwargs): + """return a SQL EXCEPT of this select() construct against the given selectable.""" + + return except_(self, other, **kwargs) + + def except_all(self, other, **kwargs): + """return a SQL EXCEPT ALL of this select() construct against the given selectable.""" + + return except_all(self, other, **kwargs) + + def intersect(self, other, **kwargs): + """return a SQL INTERSECT of this select() construct against the given selectable.""" + + return intersect(self, other, **kwargs) + + def intersect_all(self, other, **kwargs): + """return a SQL INTERSECT ALL of this select() construct against the given selectable.""" + + return intersect_all(self, other, **kwargs) + + def _table_iterator(self): + for t in visitors.NoColumnVisitor().iterate(self): + if isinstance(t, TableClause): + yield t + + def bind(self): + if self._bind: + return self._bind + for f in self._froms: + if f is self: + continue + e = f.bind + if e: + self._bind = e + return e + # look through the columns (largely synomous with looking + # through the FROMs except in the case of _CalculatedClause/_Function) + for c in self._raw_columns: + if getattr(c, 'table', None) is self: + continue + e = c.bind + if e: + self._bind = e + return e + return None + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + +class _UpdateBase(ClauseElement): + """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" + + def supports_execution(self): + return True + + def _table_iterator(self): + return iter([self.table]) + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + def _process_colparams(self, parameters): + + if parameters is None: + return None + + if isinstance(parameters, (list, tuple)): + pp = {} + for i, c in enumerate(self.table.c): + pp[c.key] = parameters[i] + return pp + else: + return parameters + + def bind(self): + return self._bind or self.table.bind + + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + +class _ValuesBase(_UpdateBase): + def values(self, *args, **kwargs): + """specify the VALUES clause for an INSERT statement, or the SET clause for an UPDATE. + + \**kwargs + key= arguments + + \*args + deprecated. A single dictionary can be sent as the first positional argument. 
+ """ + + if args: + v = args[0] + else: + v = {} + if len(v) == 0 and len(kwargs) == 0: + return self + u = self._clone() + + if u.parameters is None: + u.parameters = u._process_colparams(v) + u.parameters.update(kwargs) + else: + u.parameters = self.parameters.copy() + u.parameters.update(u._process_colparams(v)) + u.parameters.update(kwargs) + return u + +class Insert(_ValuesBase): + def __init__(self, table, values=None, inline=False, bind=None, prefixes=None, **kwargs): + self._bind = bind + self.table = table + self.select = None + self.inline=inline + if prefixes: + self._prefixes = [_literal_as_text(p) for p in prefixes] + else: + self._prefixes = [] + + self.parameters = self._process_colparams(values) + + self.kwargs = kwargs + + def get_children(self, **kwargs): + if self.select is not None: + return self.select, + else: + return () + + def _copy_internals(self, clone=_clone): + self.parameters = self.parameters.copy() + + def prefix_with(self, clause): + """Add a word or expression between INSERT and INTO. Generative. + + If multiple prefixes are supplied, they will be separated with + spaces. + """ + gen = self._generate() + clause = _literal_as_text(clause) + gen._prefixes = self._prefixes + [clause] + return gen + +class Update(_ValuesBase): + def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs): + self._bind = bind + self.table = table + if whereclause: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None + self.inline = inline + self.parameters = self._process_colparams(values) + + self.kwargs = kwargs + + def get_children(self, **kwargs): + if self._whereclause is not None: + return self._whereclause, + else: + return () + + def _copy_internals(self, clone=_clone): + self._whereclause = clone(self._whereclause) + self.parameters = self.parameters.copy() + + def where(self, whereclause): + """return a new update() construct with the given expression added to its WHERE clause, joined + to the existing clause via AND, if any.""" + + s = self._generate() + if s._whereclause is not None: + s._whereclause = and_(s._whereclause, _literal_as_text(whereclause)) + else: + s._whereclause = _literal_as_text(whereclause) + return s + + +class Delete(_UpdateBase): + def __init__(self, table, whereclause, bind=None): + self._bind = bind + self.table = table + if whereclause: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None + + def get_children(self, **kwargs): + if self._whereclause is not None: + return self._whereclause, + else: + return () + + def where(self, whereclause): + """return a new delete() construct with the given expression added to its WHERE clause, joined + to the existing clause via AND, if any.""" + + s = self._generate() + if s._whereclause is not None: + s._whereclause = and_(s._whereclause, _literal_as_text(whereclause)) + else: + s._whereclause = _literal_as_text(whereclause) + return s + + def _copy_internals(self, clone=_clone): + self._whereclause = clone(self._whereclause) + +class _IdentifiedClause(ClauseElement): + def __init__(self, ident): + self.ident = ident + def supports_execution(self): + return True + +class SavepointClause(_IdentifiedClause): + pass + +class RollbackToSavepointClause(_IdentifiedClause): + pass + +class ReleaseSavepointClause(_IdentifiedClause): + pass diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py new file mode 100644 index 0000000000..9e9a459765 --- /dev/null +++ b/lib/sqlalchemy/sql/functions.py 
@@ -0,0 +1,95 @@ +from sqlalchemy import types as sqltypes +from sqlalchemy.sql.expression import _Function, _literal_as_binds, \ + ClauseList, _FigureVisitName +from sqlalchemy.sql import operators + + +class _GenericMeta(_FigureVisitName): + def __init__(cls, clsname, bases, dict): + cls.__visit_name__ = 'function' + type.__init__(cls, clsname, bases, dict) + + def __call__(self, *args, **kwargs): + args = [_literal_as_binds(c) for c in args] + return type.__call__(self, *args, **kwargs) + +class GenericFunction(_Function): + __metaclass__ = _GenericMeta + + def __init__(self, type_=None, group=True, args=(), **kwargs): + self.packagenames = [] + self.name = self.__class__.__name__ + self._bind = kwargs.get('bind', None) + if group: + self.clause_expr = ClauseList( + operator=operators.comma_op, + group_contents=True, *args).self_group() + else: + self.clause_expr = ClauseList( + operator=operators.comma_op, + group_contents=True, *args) + self.type = sqltypes.to_instance( + type_ or getattr(self, '__return_type__', None)) + +class AnsiFunction(GenericFunction): + def __init__(self, **kwargs): + GenericFunction.__init__(self, **kwargs) + + +class coalesce(GenericFunction): + def __init__(self, *args, **kwargs): + kwargs.setdefault('type_', _type_from_args(args)) + GenericFunction.__init__(self, args=args, **kwargs) + +class now(GenericFunction): + __return_type__ = sqltypes.DateTime + +class concat(GenericFunction): + __return_type__ = sqltypes.String + def __init__(self, *args, **kwargs): + GenericFunction.__init__(self, args=args, **kwargs) + +class char_length(GenericFunction): + __return_type__ = sqltypes.Integer + + def __init__(self, arg, **kwargs): + GenericFunction.__init__(self, args=[arg], **kwargs) + +class random(GenericFunction): + def __init__(self, *args, **kwargs): + kwargs.setdefault('type_', None) + GenericFunction.__init__(self, args=args, **kwargs) + +class current_date(AnsiFunction): + __return_type__ = sqltypes.Date + +class current_time(AnsiFunction): + __return_type__ = sqltypes.Time + +class current_timestamp(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class current_user(AnsiFunction): + __return_type__ = sqltypes.String + +class localtime(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class localtimestamp(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class session_user(AnsiFunction): + __return_type__ = sqltypes.String + +class sysdate(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class user(AnsiFunction): + __return_type__ = sqltypes.String + +def _type_from_args(args): + for a in args: + if not isinstance(a.type, sqltypes.NullType): + return a.type + else: + return sqltypes.NullType diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py new file mode 100644 index 0000000000..dfd638ecb1 --- /dev/null +++ b/lib/sqlalchemy/sql/operators.py @@ -0,0 +1,119 @@ +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines operators used in SQL expressions.""" + +from operator import and_, or_, inv, add, mul, sub, div, mod, truediv, \ + lt, le, ne, gt, ge, eq +from sqlalchemy.util import Set, symbol + +def from_(): + raise NotImplementedError() + +def as_(): + raise NotImplementedError() + +def exists(): + raise NotImplementedError() + +def is_(): + raise NotImplementedError() + +def isnot(): + raise NotImplementedError() + +def collate(): + raise NotImplementedError() + +def op(a, opstring, b): + return a.op(opstring)(b) + 
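Stepping back to the functions.py module introduced above, a brief sketch of how the
GenericFunction classes are meant to be used (the `users` table is hypothetical, and
func.now(), func.coalesce() etc. from the expression package are intended to resolve
to these same classes)::

    from sqlalchemy import MetaData, Table, Column, String
    from sqlalchemy.sql import functions

    metadata = MetaData()
    users = Table('users', metadata, Column('name', String(50)))

    # now() picks up a DateTime return type from __return_type__
    ts = functions.now()

    # coalesce() derives its return type via _type_from_args(): here the
    # String type of users.c.name, since the literal has no type of its own
    expr = functions.coalesce(users.c.name, 'anonymous')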
+def like_op(a, b, escape=None): + return a.like(b, escape=escape) + +def notlike_op(a, b, escape=None): + raise NotImplementedError() + +def ilike_op(a, b, escape=None): + return a.ilike(b, escape=escape) + +def notilike_op(a, b, escape=None): + raise NotImplementedError() + +def between_op(a, b, c): + return a.between(b, c) + +def in_op(a, b): + return a.in_(*b) + +def notin_op(a, b): + raise NotImplementedError() + +def distinct_op(a): + return a.distinct() + +def startswith_op(a, b, escape=None): + return a.startswith(b, escape=escape) + +def endswith_op(a, b, escape=None): + return a.endswith(b, escape=escape) + +def contains_op(a, b, escape=None): + return a.contains(b, escape=escape) + +def comma_op(a, b): + raise NotImplementedError() + +def concat_op(a, b): + return a.concat(b) + +def desc_op(a): + return a.desc() + +def asc_op(a): + return a.asc() + +_commutative = Set([eq, ne, add, mul]) +def is_commutative(op): + return op in _commutative + +_smallest = symbol('_smallest') +_largest = symbol('_largest') + +_PRECEDENCE = { + from_:15, + mul:7, + div:7, + mod:7, + add:6, + sub:6, + concat_op:6, + ilike_op:5, + notilike_op:5, + like_op:5, + notlike_op:5, + in_op:5, + notin_op:5, + is_:5, + isnot:5, + eq:5, + ne:5, + gt:5, + lt:5, + ge:5, + le:5, + between_op:5, + distinct_op:5, + inv:5, + and_:3, + or_:2, + comma_op:-1, + collate: -2, + as_:-1, + exists:0, + _smallest: -1000, + _largest: 1000 +} + +def is_precedent(operator, against): + return _PRECEDENCE.get(operator, _PRECEDENCE[_smallest]) <= _PRECEDENCE.get(against, _PRECEDENCE[_largest]) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py new file mode 100644 index 0000000000..d299982cfa --- /dev/null +++ b/lib/sqlalchemy/sql/util.py @@ -0,0 +1,365 @@ +from sqlalchemy import exceptions, schema, topological, util, sql +from sqlalchemy.sql import expression, operators, visitors +from itertools import chain + +"""Utility functions that build upon SQL and Schema constructs.""" + +def sort_tables(tables, reverse=False): + """sort a collection of Table objects in order of their foreign-key dependency.""" + + tuples = [] + class TVisitor(schema.SchemaVisitor): + def visit_foreign_key(_self, fkey): + if fkey.use_alter: + return + parent_table = fkey.column.table + if parent_table in tables: + child_table = fkey.parent.table + tuples.append( ( parent_table, child_table ) ) + vis = TVisitor() + for table in tables: + vis.traverse(table) + sequence = topological.sort(tuples, tables) + if reverse: + return util.reversed(sequence) + else: + return sequence + +def find_tables(clause, check_columns=False, include_aliases=False): + """locate Table objects within the given expression.""" + + tables = [] + kwargs = {} + if include_aliases: + def visit_alias(alias): + tables.append(alias) + kwargs['visit_alias'] = visit_alias + + if check_columns: + def visit_column(column): + tables.append(column.table) + kwargs['visit_column'] = visit_column + + def visit_table(table): + tables.append(table) + kwargs['visit_table'] = visit_table + + visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs) + return tables + +def find_columns(clause): + """locate Column objects within the given expression.""" + + cols = util.Set() + def visit_column(col): + cols.add(col) + visitors.traverse(clause, visit_column=visit_column) + return cols + +def join_condition(a, b, ignore_nonexistent_tables=False): + """create a join condition between two tables. 
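Referring back to the _PRECEDENCE table in operators.py above, is_precedent() reports
whether an operator binds no more tightly than the operator it appears against, which
is roughly the information the expression layer uses when deciding where grouping
parentheses are needed (a small sketch)::

    from sqlalchemy.sql import operators

    # add (precedence 6) placed against mul (precedence 7): True, meaning the
    # addition would need to be grouped inside the multiplication
    operators.is_precedent(operators.add, operators.mul)   # True
    operators.is_precedent(operators.mul, operators.add)   # False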
+
+    ignore_nonexistent_tables=True allows a join condition to be
+    determined between two tables which may contain references to
+    other not-yet-defined tables.  In general the NoSuchTableError
+    raised is only required if the user is trying to join selectables
+    across multiple MetaData objects (which is an extremely rare use
+    case).
+
+    """
+    crit = []
+    constraints = util.Set()
+    for fk in b.foreign_keys:
+        try:
+            col = fk.get_referent(a)
+        except exceptions.NoReferencedTableError:
+            if ignore_nonexistent_tables:
+                continue
+            else:
+                raise
+
+        if col:
+            crit.append(col == fk.parent)
+            constraints.add(fk.constraint)
+
+    if a is not b:
+        for fk in a.foreign_keys:
+            try:
+                col = fk.get_referent(b)
+            except exceptions.NoReferencedTableError:
+                if ignore_nonexistent_tables:
+                    continue
+                else:
+                    raise
+
+            if col:
+                crit.append(col == fk.parent)
+                constraints.add(fk.constraint)
+
+    if len(crit) == 0:
+        raise exceptions.ArgumentError(
+            "Can't find any foreign key relationships "
+            "between '%s' and '%s'" % (a.description, b.description))
+    elif len(constraints) > 1:
+        raise exceptions.ArgumentError(
+            "Can't determine join between '%s' and '%s'; "
+            "tables have more than one foreign key "
+            "constraint relationship between them. "
+            "Please specify the 'onclause' of this "
+            "join explicitly." % (a.description, b.description))
+    elif len(crit) == 1:
+        return (crit[0])
+    else:
+        return sql.and_(*crit)
+
+
+def reduce_columns(columns, *clauses):
+    """given a list of columns, return a 'reduced' set based on natural equivalents.
+
+    the set is reduced to the smallest list of columns which have no natural
+    equivalent present in the list.  A "natural equivalent" means that two columns
+    will ultimately represent the same value because they are related by a foreign key.
+
+    \*clauses is an optional list of join clauses which will be traversed
+    to further identify columns that are "equivalent".
+
+    This function is primarily used to determine the most minimal "primary key"
+    from a selectable, by reducing the set of primary key columns present
+    in the selectable to just those that are not repeated.
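A hedged usage sketch for join_condition() above, using two hypothetical tables
related by a single foreign key::

    from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
    from sqlalchemy.sql import util as sql_util

    metadata = MetaData()
    parent = Table('parent', metadata,
        Column('id', Integer, primary_key=True))
    child = Table('child', metadata,
        Column('id', Integer, primary_key=True),
        Column('parent_id', Integer, ForeignKey('parent.id')))

    cond = sql_util.join_condition(parent, child)
    # equivalent to: parent.c.id == child.c.parent_id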
+ + """ + + columns = util.OrderedSet(columns) + + omit = util.Set() + for col in columns: + for fk in col.foreign_keys: + for c in columns: + if c is col: + continue + if fk.column.shares_lineage(c): + omit.add(col) + break + + if clauses: + def visit_binary(binary): + if binary.operator == operators.eq: + cols = util.Set(chain(*[c.proxy_set for c in columns.difference(omit)])) + if binary.left in cols and binary.right in cols: + for c in columns: + if c.shares_lineage(binary.right): + omit.add(c) + break + for clause in clauses: + visitors.traverse(clause, visit_binary=visit_binary) + + return expression.ColumnSet(columns.difference(omit)) + +def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False): + """traverse an expression and locate binary criterion pairs.""" + + if consider_as_foreign_keys and consider_as_referenced_keys: + raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'") + + def visit_binary(binary): + if not any_operator and binary.operator != operators.eq: + return + if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement): + return + + if consider_as_foreign_keys: + if binary.left in consider_as_foreign_keys: + pairs.append((binary.right, binary.left)) + elif binary.right in consider_as_foreign_keys: + pairs.append((binary.left, binary.right)) + elif consider_as_referenced_keys: + if binary.left in consider_as_referenced_keys: + pairs.append((binary.left, binary.right)) + elif binary.right in consider_as_referenced_keys: + pairs.append((binary.right, binary.left)) + else: + if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column): + if binary.left.references(binary.right): + pairs.append((binary.right, binary.left)) + elif binary.right.references(binary.left): + pairs.append((binary.left, binary.right)) + pairs = [] + visitors.traverse(expression, visit_binary=visit_binary) + return pairs + +def folded_equivalents(join, equivs=None): + """Returns the column list of the given Join with all equivalently-named, + equated columns folded into one column, where 'equated' means they are + equated to each other in the ON clause of this join. + + This function is used by Join.select(fold_equivalents=True). + + TODO: deprecate ? 
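reduce_columns() above can be sketched with a hypothetical joined-inheritance pair of
tables: the foreign-key column is a "natural equivalent" of the column it references
and drops out of the reduced set::

    from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
    from sqlalchemy.sql import util as sql_util

    metadata = MetaData()
    base = Table('base', metadata,
        Column('id', Integer, primary_key=True))
    sub = Table('sub', metadata,
        Column('id', Integer, ForeignKey('base.id'), primary_key=True))

    reduced = sql_util.reduce_columns([base.c.id, sub.c.id])
    # the reduced set contains only base.c.id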
+ """ + + if equivs is None: + equivs = util.Set() + def visit_binary(binary): + if binary.operator == operators.eq and binary.left.name == binary.right.name: + equivs.add(binary.right) + equivs.add(binary.left) + visitors.traverse(join.onclause, visit_binary=visit_binary) + collist = [] + if isinstance(join.left, expression.Join): + left = folded_equivalents(join.left, equivs) + else: + left = list(join.left.columns) + if isinstance(join.right, expression.Join): + right = folded_equivalents(join.right, equivs) + else: + right = list(join.right.columns) + used = util.Set() + for c in left + right: + if c in equivs: + if c.name not in used: + collist.append(c) + used.add(c.name) + else: + collist.append(c) + return collist + +class AliasedRow(object): + + def __init__(self, row, map): + # AliasedRow objects don't nest, so un-nest + # if another AliasedRow was passed + if isinstance(row, AliasedRow): + self.row = row.row + else: + self.row = row + self.map = map + + def __contains__(self, key): + return self.map[key] in self.row + + def has_key(self, key): + return key in self + + def __getitem__(self, key): + return self.row[self.map[key]] + + def keys(self): + return self.row.keys() + +def row_adapter(from_, equivalent_columns=None): + """create a row adapter callable against a selectable.""" + + if equivalent_columns is None: + equivalent_columns = {} + + def locate_col(col): + c = from_.corresponding_column(col) + if c: + return c + elif col in equivalent_columns: + for c2 in equivalent_columns[col]: + corr = from_.corresponding_column(c2) + if corr: + return corr + return col + + map = util.PopulateDict(locate_col) + + def adapt(row): + return AliasedRow(row, map) + return adapt + +class ColumnsInClause(visitors.ClauseVisitor): + """Given a selectable, visit clauses and determine if any columns + from the clause are in the selectable. + """ + + def __init__(self, selectable): + self.selectable = selectable + self.result = False + + def visit_column(self, column): + if self.selectable.c.get(column.key) is column: + self.result = True + +class ClauseAdapter(visitors.ClauseVisitor): + """Given a clause (like as in a WHERE criterion), locate columns + which are embedded within a given selectable, and changes those + columns to be that of the selectable. + + E.g.:: + + table1 = Table('sometable', metadata, + Column('col1', Integer), + Column('col2', Integer) + ) + table2 = Table('someothertable', metadata, + Column('col1', Integer), + Column('col2', Integer) + ) + + condition = table1.c.col1 == table2.c.col1 + + and make an alias of table1:: + + s = table1.alias('foo') + + calling ``ClauseAdapter(s).traverse(condition)`` converts + condition to read:: + + s.c.col1 == table2.c.col1 + """ + + __traverse_options__ = {'column_collections':False} + + def __init__(self, selectable, include=None, exclude=None, equivalents=None): + self.__traverse_options__ = self.__traverse_options__.copy() + self.__traverse_options__['stop_on'] = [selectable] + self.selectable = selectable + self.include = include + self.exclude = exclude + self.equivalents = equivalents + + def traverse(self, obj, clone=True): + if not clone: + raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True") + return visitors.ClauseVisitor.traverse(self, obj, clone=True) + + def copy_and_chain(self, adapter): + """create a copy of this adapter and chain to the given adapter. + + currently this adapter must be unchained to start, raises + an exception if it's already chained. + + Does not modify the given adapter. 
+ """ + + if adapter is None: + return self + + if hasattr(self, '_next'): + raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)") + + ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents) + ca._next = adapter + return ca + + def before_clone(self, col): + if isinstance(col, expression.FromClause): + if self.selectable.is_derived_from(col): + return self.selectable + if not isinstance(col, expression.ColumnElement): + return None + if self.include is not None: + if col not in self.include: + return None + if self.exclude is not None: + if col in self.exclude: + return None + newcol = self.selectable.corresponding_column(col, require_embedded=True) + if newcol is None and self.equivalents is not None and col in self.equivalents: + for equiv in self.equivalents[col]: + newcol = self.selectable.corresponding_column(equiv, require_embedded=True) + if newcol: + return newcol + return newcol diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py new file mode 100644 index 0000000000..9888a228a3 --- /dev/null +++ b/lib/sqlalchemy/sql/visitors.py @@ -0,0 +1,182 @@ +from sqlalchemy import util + +class ClauseVisitor(object): + """Traverses and visits ``ClauseElement`` structures. + + Calls visit_XXX() methods for each particular + ``ClauseElement`` subclass encountered. Traversal of a + hierarchy of ``ClauseElements`` is achieved via the + ``traverse()`` method, which is passed the lead + ``ClauseElement``. + + By default, ``ClauseVisitor`` traverses all elements + fully. Options can be specified at the class level via the + ``__traverse_options__`` dictionary which will be passed + to the ``get_children()`` method of each ``ClauseElement``; + these options can indicate modifications to the set of + elements returned, such as to not return column collections + (column_collections=False) or to return Schema-level items + (schema_visitor=True). + + ``ClauseVisitor`` also supports a simultaneous copy-and-traverse + operation, which will produce a copy of a given ``ClauseElement`` + structure while at the same time allowing ``ClauseVisitor`` subclasses + to modify the new structure in-place. + + """ + __traverse_options__ = {} + + def traverse_single(self, obj, **kwargs): + """visit a single element, without traversing its child elements.""" + + for v in self._iterate_visitors: + meth = getattr(v, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kwargs) + + traverse_chained = traverse_single + + def iterate(self, obj): + """traverse the given expression structure, returning an iterator of all elements.""" + + stack = [obj] + traversal = util.deque() + while stack: + t = stack.pop() + traversal.appendleft(t) + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + return iter(traversal) + + def traverse(self, obj, clone=False): + """traverse and visit the given expression structure. + + Returns the structure given, or a copy of the structure if + clone=True. + + When the copy operation takes place, the before_clone() method + will receive each element before it is copied. If the method + returns a non-None value, the return value is taken as the + "copied" element and traversal will not descend further. + + The visit_XXX() methods receive the element *after* it's been + copied. 
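To complement the ClauseAdapter docstring above, a hedged end-to-end version of its
example; traverse() runs the cloning traversal, so the original condition is left
untouched::

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.sql import util as sql_util

    metadata = MetaData()
    table1 = Table('sometable', metadata,
        Column('col1', Integer),
        Column('col2', Integer))
    table2 = Table('someothertable', metadata,
        Column('col1', Integer),
        Column('col2', Integer))

    condition = table1.c.col1 == table2.c.col1
    s = table1.alias('foo')

    adapted = sql_util.ClauseAdapter(s).traverse(condition)
    # str(adapted) should read roughly "foo.col1 = someothertable.col1";
    # str(condition) is unchanged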
To compare an element to another regardless of + one element being a cloned copy of the original, the + '_cloned_set' attribute of ClauseElement can be used for the compare, + i.e.:: + + original in copied._cloned_set + + + """ + if clone: + return self._cloned_traversal(obj) + else: + return self._non_cloned_traversal(obj) + + def copy_and_process(self, list_): + """Apply cloned traversal to the given list of elements, and return the new list.""" + + return [self._cloned_traversal(x) for x in list_] + + def before_clone(self, elem): + """receive pre-copied elements during a cloning traversal. + + If the method returns a new element, the element is used + instead of creating a simple copy of the element. Traversal + will halt on the newly returned element if it is re-encountered. + """ + return None + + def _clone_element(self, elem, stop_on, cloned): + for v in self._iterate_visitors: + newelem = v.before_clone(elem) + if newelem: + stop_on.add(newelem) + return newelem + + if elem not in cloned: + # the full traversal will only make a clone of a particular element + # once. + cloned[elem] = elem._clone() + return cloned[elem] + + def _cloned_traversal(self, obj): + """a recursive traversal which creates copies of elements, returning the new structure.""" + + stop_on = self.__traverse_options__.get('stop_on', []) + return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True) + + def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False): + if elem in stop_on: + return elem + + if _clone_toplevel: + elem = self._clone_element(elem, stop_on, cloned) + if elem in stop_on: + return elem + + def clone(element): + return self._clone_element(element, stop_on, cloned) + elem._copy_internals(clone=clone) + + self.traverse_single(elem) + + for e in elem.get_children(**self.__traverse_options__): + if e not in stop_on: + self._cloned_traversal_impl(e, stop_on, cloned) + return elem + + def _non_cloned_traversal(self, obj): + """a non-recursive, non-cloning traversal.""" + + for target in self.iterate(obj): + self.traverse_single(target) + return obj + + def _iterate_visitors(self): + """iterate through this visitor and each 'chained' visitor.""" + + v = self + while v: + yield v + v = getattr(v, '_next', None) + _iterate_visitors = property(_iterate_visitors) + + def chain(self, visitor): + """'chain' an additional ClauseVisitor onto this ClauseVisitor. + + the chained visitor will receive all visit events after this one. + """ + tail = list(self._iterate_visitors)[-1] + tail._next = visitor + return self + +class NoColumnVisitor(ClauseVisitor): + """ClauseVisitor with 'column_collections' set to False; will not + traverse the front-facing Column collections on Table, Alias, Select, + and CompoundSelect objects. 
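A minimal, hypothetical ClauseVisitor subclass showing the visit_XXX() dispatch
described above (this mirrors what the find_columns() helper in sql/util.py does)::

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.sql import visitors

    metadata = MetaData()
    t = Table('t', metadata, Column('x', Integer), Column('y', Integer))

    class ColumnCollector(visitors.ClauseVisitor):
        """collect every Column encountered during traversal."""

        def __init__(self):
            self.columns = set()

        def visit_column(self, column):
            self.columns.add(column)

    vis = ColumnCollector()
    vis.traverse(t.c.x == 5)        # non-cloning traversal
    # vis.columns now contains t.c.x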
+ + """ + + __traverse_options__ = {'column_collections':False} + +class NullVisitor(ClauseVisitor): + def traverse(self, obj, clone=False): + next = getattr(self, '_next', None) + if next: + return next.traverse(obj, clone=clone) + else: + return obj + +def traverse(clause, **kwargs): + """traverse the given clause, applying visit functions passed in as keyword arguments.""" + + clone = kwargs.pop('clone', False) + class Vis(ClauseVisitor): + __traverse_options__ = kwargs.pop('traverse_options', {}) + vis = Vis() + for key in kwargs: + setattr(vis, key, kwargs[key]) + return vis.traverse(clause, clone=clone) + diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py deleted file mode 100644 index d91fbe4b52..0000000000 --- a/lib/sqlalchemy/sql_util.py +++ /dev/null @@ -1,241 +0,0 @@ -from sqlalchemy import sql, util, schema, topological - -"""Utility functions that build upon SQL and Schema constructs.""" - -class TableCollection(object): - def __init__(self, tables=None): - self.tables = tables or [] - - def __len__(self): - return len(self.tables) - - def __getitem__(self, i): - return self.tables[i] - - def __iter__(self): - return iter(self.tables) - - def __contains__(self, obj): - return obj in self.tables - - def __add__(self, obj): - return self.tables + list(obj) - - def add(self, table): - self.tables.append(table) - if hasattr(self, '_sorted'): - del self._sorted - - def sort(self, reverse=False): - try: - sorted = self._sorted - except AttributeError, e: - self._sorted = self._do_sort() - sorted = self._sorted - if reverse: - x = sorted[:] - x.reverse() - return x - else: - return sorted - - def _do_sort(self): - tuples = [] - class TVisitor(schema.SchemaVisitor): - def visit_foreign_key(_self, fkey): - if fkey.use_alter: - return - parent_table = fkey.column.table - if parent_table in self: - child_table = fkey.parent.table - tuples.append( ( parent_table, child_table ) ) - vis = TVisitor() - for table in self.tables: - vis.traverse(table) - sorter = topological.QueueDependencySorter( tuples, self.tables ) - head = sorter.sort() - sequence = [] - def to_sequence( node, seq=sequence): - seq.append( node.item ) - for child in node.children: - to_sequence( child ) - if head is not None: - to_sequence( head ) - return sequence - - -class TableFinder(TableCollection, sql.NoColumnVisitor): - """locate all Tables within a clause.""" - - def __init__(self, clause, check_columns=False, include_aliases=False): - TableCollection.__init__(self) - self.check_columns = check_columns - self.include_aliases = include_aliases - for clause in util.to_list(clause): - self.traverse(clause) - - def visit_alias(self, alias): - if self.include_aliases: - self.tables.append(alias) - - def visit_table(self, table): - self.tables.append(table) - - def visit_column(self, column): - if self.check_columns: - self.tables.append(column.table) - -class ColumnFinder(sql.ClauseVisitor): - def __init__(self): - self.columns = util.Set() - - def visit_column(self, c): - self.columns.add(c) - - def __iter__(self): - return iter(self.columns) - -class ColumnsInClause(sql.ClauseVisitor): - """Given a selectable, visit clauses and determine if any columns - from the clause are in the selectable. 
- """ - - def __init__(self, selectable): - self.selectable = selectable - self.result = False - - def visit_column(self, column): - if self.selectable.c.get(column.key) is column: - self.result = True - -class AbstractClauseProcessor(sql.NoColumnVisitor): - """Traverse a clause and attempt to convert the contents of container elements - to a converted element. - - The conversion operation is defined by subclasses. - """ - - def convert_element(self, elem): - """Define the *conversion* method for this ``AbstractClauseProcessor``.""" - - raise NotImplementedError() - - def copy_and_process(self, list_): - """Copy the container elements in the given list to a new list and - process the new list. - """ - - list_ = list(list_) - self.process_list(list_) - return list_ - - def process_list(self, list_): - """Process all elements of the given list in-place.""" - - for i in range(0, len(list_)): - elem = self.convert_element(list_[i]) - if elem is not None: - list_[i] = elem - else: - list_[i] = self.traverse(list_[i], clone=True) - - def visit_grouping(self, grouping): - elem = self.convert_element(grouping.elem) - if elem is not None: - grouping.elem = elem - - def visit_clauselist(self, clist): - for i in range(0, len(clist.clauses)): - n = self.convert_element(clist.clauses[i]) - if n is not None: - clist.clauses[i] = n - - def visit_unary(self, unary): - elem = self.convert_element(unary.element) - if elem is not None: - unary.element = elem - - def visit_binary(self, binary): - elem = self.convert_element(binary.left) - if elem is not None: - binary.left = elem - elem = self.convert_element(binary.right) - if elem is not None: - binary.right = elem - - def visit_select(self, select): - fr = util.OrderedSet() - for elem in select._froms: - n = self.convert_element(elem) - if n is not None: - fr.add((elem, n)) - select._recorrelate_froms(fr) - - col = [] - for elem in select._raw_columns: - print "RAW COLUMN", elem - n = self.convert_element(elem) - if n is None: - col.append(elem) - else: - col.append(n) - select._raw_columns = col - -class ClauseAdapter(AbstractClauseProcessor): - """Given a clause (like as in a WHERE criterion), locate columns - which are embedded within a given selectable, and changes those - columns to be that of the selectable. 
- - E.g.:: - - table1 = Table('sometable', metadata, - Column('col1', Integer), - Column('col2', Integer) - ) - table2 = Table('someothertable', metadata, - Column('col1', Integer), - Column('col2', Integer) - ) - - condition = table1.c.col1 == table2.c.col1 - - and make an alias of table1:: - - s = table1.alias('foo') - - calling ``ClauseAdapter(s).traverse(condition)`` converts - condition to read:: - - s.c.col1 == table2.c.col1 - """ - - def __init__(self, selectable, include=None, exclude=None, equivalents=None): - self.selectable = selectable - self.include = include - self.exclude = exclude - self.equivalents = equivalents - - def convert_element(self, col): - if isinstance(col, sql.FromClause): - if self.selectable.is_derived_from(col): - return self.selectable - if not isinstance(col, sql.ColumnElement): - return None - if self.include is not None: - if col not in self.include: - return None - if self.exclude is not None: - if col in self.exclude: - return None - newcol = self.selectable.corresponding_column(col, raiseerr=False, require_embedded=True, keys_ok=False) - if newcol is None and self.equivalents is not None and col in self.equivalents: - for equiv in self.equivalents[col]: - newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False) - if newcol: - return newcol - #if newcol is None: - # self.traverse(col) - # return col - return newcol - - diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index 56c8cb46e4..9961239798 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -1,23 +1,11 @@ # topological.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """Topological sorting algorithms. -The key to the unit of work is to assemble a list of dependencies -amongst all the different mappers that have been defined for classes. - -Related tables with foreign key constraints have a definite insert -order, deletion order, objects need dependent properties from parent -objects set up before saved, etc. - -These are all encoded as dependencies, in the form *mapper X is -dependent on mapper Y*, meaning mapper Y's objects must be saved -before those of mapper X, and mapper X's objects must be deleted -before those of mapper Y. - The topological sort is an algorithm that receives this list of dependencies as a *partial ordering*, that is a list of pairs which might say, *X is dependent on Y*, *Q is dependent on Z*, but does not @@ -28,30 +16,51 @@ then only towards just some of the other elements. For a particular partial ordering, there can be many possible sorts that satisfy the conditions. -An intrinsic *gotcha* to this algorithm is that since there are many -possible outcomes to sorting a partial ordering, the algorithm can -return any number of different results for the same input; just -running it on a different machine architecture, or just random -differences in the ordering of dictionaries, can change the result -that is returned. While this result is guaranteed to be true to the -incoming partial ordering, if the partial ordering itself does not -properly represent the dependencies, code that works fine will -suddenly break, then work again, then break, etc. 
Most of the bugs -I've chased down while developing the *unit of work* have been of this -nature - very tricky to reproduce and track down, particularly before -I realized this characteristic of the algorithm. """ from sqlalchemy import util from sqlalchemy.exceptions import CircularDependencyError -class _Node(object): - """Represent each item in the sort. +__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree'] - While the topological sort produces a straight ordered list of - items, ``_Node`` ultimately stores a tree-structure of those items - which are organized so that non-dependent nodes are siblings. +def sort(tuples, allitems): + """sort the given list of items by dependency. + + 'tuples' is a list of tuples representing a partial ordering. + """ + + return [n.item for n in _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=True)] + +def sort_with_cycles(tuples, allitems): + """sort the given list of items by dependency, cutting out cycles. + + returns results as an iterable of 2-tuples, containing the item, + and a list containing items involved in a cycle with this item, if any. + + 'tuples' is a list of tuples representing a partial ordering. """ + + return [(n.item, [n.item for n in n.cycles or []]) for n in _sort(tuples, allitems, allow_cycles=True)] + +def sort_as_tree(tuples, allitems, with_cycles=False): + """sort the given list of items by dependency, and return results + as a hierarchical tree structure. + + returns results as an iterable of 3-tuples, containing the item, + and a list containing items involved in a cycle with this item, if any, + and a list of child tuples. + + if with_cycles is False, the returned structure is of the same form + but the second element of each tuple, i.e. the 'cycles', is an empty list. + + 'tuples' is a list of tuples representing a partial ordering. 
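A small sketch of the module-level sorting API above; items here are plain strings,
but any hashable items work, and many valid orderings exist for a given partial
ordering::

    from sqlalchemy import topological

    # partial ordering: 'a' before 'b', 'b' before 'c'; 'd' is unconstrained
    topological.sort([('a', 'b'), ('b', 'c')], ['c', 'b', 'a', 'd'])
    # one valid result: ['a', 'b', 'c', 'd']

    # sort_with_cycles() folds a cycle into a single entry which carries
    # the full list of items involved in that cycle
    topological.sort_with_cycles([('a', 'b'), ('b', 'a')], ['a', 'b'])
    # e.g. [('a', ['a', 'b'])]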
+ """ + + return _organize_as_tree(_sort(tuples, allitems, allow_cycles=with_cycles)) + + +class _Node(object): + """Represent each item in the sort.""" def __init__(self, item): self.item = item @@ -61,13 +70,13 @@ class _Node(object): def __str__(self): return self.safestr() - + def safestr(self, indent=0): return (' ' * indent * 2) + \ str(self.item) + \ (self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \ "\n" + \ - ''.join([n.safestr(indent + 1) for n in self.children]) + ''.join([str(n) for n in self.children]) def __repr__(self): return "%s" % (str(self.item)) @@ -92,10 +101,10 @@ class _EdgeCollection(object): """Add an edge to this collection.""" (parentnode, childnode) = edge - if not self.parent_to_children.has_key(parentnode): + if parentnode not in self.parent_to_children: self.parent_to_children[parentnode] = util.Set() self.parent_to_children[parentnode].add(childnode) - if not self.child_to_parents.has_key(childnode): + if childnode not in self.child_to_parents: self.child_to_parents[childnode] = util.Set() self.child_to_parents[childnode].add(parentnode) parentnode.dependencies.add(childnode) @@ -115,10 +124,10 @@ class _EdgeCollection(object): return None def has_parents(self, node): - return self.child_to_parents.has_key(node) and len(self.child_to_parents[node]) > 0 + return node in self.child_to_parents and len(self.child_to_parents[node]) > 0 def edges_by_parent(self, node): - if self.parent_to_children.has_key(node): + if node in self.parent_to_children: return [(node, child) for child in self.parent_to_children[node]] else: return [] @@ -137,7 +146,7 @@ class _EdgeCollection(object): if children is not None: for child in children: self.child_to_parents[child].remove(node) - if not len(self.child_to_parents[child]): + if not self.child_to_parents[child]: yield child def __len__(self): @@ -154,162 +163,146 @@ class _EdgeCollection(object): def __repr__(self): return repr(list(self)) -class QueueDependencySorter(object): - """Topological sort adapted from wikipedia's article on the subject. - - It creates a straight-line list of elements, then a second pass - groups non-dependent actions together to build more of a tree - structure with siblings. 
+def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False): + nodes = {} + edges = _EdgeCollection() + + for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]: + if id(item) not in nodes: + node = _Node(item) + nodes[item] = node + + for t in tuples: + if t[0] is t[1]: + if allow_cycles: + n = nodes[t[0]] + n.cycles = util.Set([n]) + elif not ignore_self_cycles: + raise CircularDependencyError("Self-referential dependency detected " + repr(t)) + continue + childnode = nodes[t[1]] + parentnode = nodes[t[0]] + edges.add((parentnode, childnode)) + + queue = [] + for n in nodes.values(): + if not edges.has_parents(n): + queue.append(n) + + output = [] + while nodes: + if not queue: + # edges remain but no edgeless nodes to remove; this indicates + # a cycle + if allow_cycles: + for cycle in _find_cycles(edges): + lead = cycle[0][0] + lead.cycles = util.Set() + for edge in cycle: + n = edges.remove(edge) + lead.cycles.add(edge[0]) + lead.cycles.add(edge[1]) + if n is not None: + queue.append(n) + for n in lead.cycles: + if n is not lead: + n._cyclical = True + for (n,k) in list(edges.edges_by_parent(n)): + edges.add((lead, k)) + edges.remove((n,k)) + continue + else: + # long cycles not allowed + raise CircularDependencyError("Circular dependency detected " + repr(edges) + repr(queue)) + node = queue.pop() + if not hasattr(node, '_cyclical'): + output.append(node) + del nodes[node.item] + for childnode in edges.pop_node(node): + queue.append(childnode) + return output + +def _organize_as_tree(nodes): + """Given a list of nodes from a topological sort, organize the + nodes into a tree structure, with as many non-dependent nodes + set as siblings to each other as possible. + + returns nodes as 3-tuples (item, cycles, children). 
""" - def __init__(self, tuples, allitems): - self.tuples = tuples - self.allitems = allitems - - def sort(self, allow_self_cycles=True, allow_all_cycles=False): - (tuples, allitems) = (self.tuples, self.allitems) - #print "\n---------------------------------\n" - #print repr([t for t in tuples]) - #print repr([a for a in allitems]) - #print "\n---------------------------------\n" - - nodes = {} - edges = _EdgeCollection() - for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]: - if not nodes.has_key(item): - node = _Node(item) - nodes[item] = node - - for t in tuples: - if t[0] is t[1]: - if allow_self_cycles: - n = nodes[t[0]] - n.cycles = util.Set([n]) - continue - else: - raise CircularDependencyError("Self-referential dependency detected " + repr(t)) - childnode = nodes[t[1]] - parentnode = nodes[t[0]] - edges.add((parentnode, childnode)) - - queue = [] - for n in nodes.values(): - if not edges.has_parents(n): - queue.append(n) - cycles = {} - output = [] - while len(nodes) > 0: - if len(queue) == 0: - # edges remain but no edgeless nodes to remove; this indicates - # a cycle - if allow_all_cycles: - for cycle in self._find_cycles(edges): - lead = cycle[0][0] - lead.cycles = util.Set() - for edge in cycle: - n = edges.remove(edge) - lead.cycles.add(edge[0]) - lead.cycles.add(edge[1]) - if n is not None: - queue.append(n) - for n in lead.cycles: - if n is not lead: - n._cyclical = True - for (n,k) in list(edges.edges_by_parent(n)): - edges.add((lead, k)) - edges.remove((n,k)) - continue - else: - # long cycles not allowed - raise CircularDependencyError("Circular dependency detected " + repr(edges) + repr(queue)) - node = queue.pop() - if not hasattr(node, '_cyclical'): - output.append(node) - del nodes[node.item] - for childnode in edges.pop_node(node): - queue.append(childnode) - return self._create_batched_tree(output) - - - def _create_batched_tree(self, nodes): - """Given a list of nodes from a topological sort, organize the - nodes into a tree structure, with as many non-dependent nodes - set as siblings to each other as possible. 
- """ - - if not len(nodes): - return None - # a list of all currently independent subtrees as a tuple of - # (root_node, set_of_all_tree_nodes, set_of_all_cycle_nodes_in_tree) - # order of the list has no semantics for the algorithmic - independents = [] - # in reverse topological order - for node in util.reversed(nodes): - # nodes subtree and cycles contain the node itself - subtree = util.Set([node]) - if node.cycles is not None: - cycles = util.Set(node.cycles) - else: - cycles = util.Set() - # get a set of dependent nodes of node and its cycles - nodealldeps = node.all_deps() - if nodealldeps: - # iterate over independent node indexes in reverse order so we can efficiently remove them - for index in xrange(len(independents)-1,-1,-1): - child, childsubtree, childcycles = independents[index] - # if there is a dependency between this node and an independent node - if (childsubtree.intersection(nodealldeps) or childcycles.intersection(node.dependencies)): - # prepend child to nodes children - # (append should be fine, but previous implemetation used prepend) - node.children[0:0] = (child,) - # merge childs subtree and cycles - subtree.update(childsubtree) - cycles.update(childcycles) - # remove the child from list of independent subtrees - independents[index:index+1] = [] - # add node as a new independent subtree - independents.append((node,subtree,cycles)) - # choose an arbitrary node from list of all independent subtrees - head = independents.pop()[0] - # add all other independent subtrees as a child of the chosen root - # used prepend [0:0] instead of extend to maintain exact behaviour of previous implementation - head.children[0:0] = [i[0] for i in independents] - return head - - def _find_cycles(self, edges): - involved_in_cycles = util.Set() - cycles = {} - def traverse(node, goal=None, cycle=None): - if goal is None: - goal = node - cycle = [] - elif node is goal: - return True - - for (n, key) in edges.edges_by_parent(node): - if key in cycle: - continue - cycle.append(key) - if traverse(key, goal, cycle): - cycset = util.Set(cycle) - for x in cycle: - involved_in_cycles.add(x) - if cycles.has_key(x): - existing_set = cycles[x] - [existing_set.add(y) for y in cycset] - for y in existing_set: - cycles[y] = existing_set - cycset = existing_set - else: - cycles[x] = cycset - cycle.pop() - - for parent in edges.get_parents(): - traverse(parent) - - for cycle in dict([(id(s), s) for s in cycles.values()]).values(): - edgecollection = [] - for edge in edges: - if edge[0] in cycle and edge[1] in cycle: - edgecollection.append(edge) - yield edgecollection + if not nodes: + return None + # a list of all currently independent subtrees as a tuple of + # (root_node, set_of_all_tree_nodes, set_of_all_cycle_nodes_in_tree) + # order of the list has no semantics for the algorithmic + independents = [] + # in reverse topological order + for node in util.reversed(nodes): + # nodes subtree and cycles contain the node itself + subtree = util.Set([node]) + if node.cycles is not None: + cycles = util.Set(node.cycles) + else: + cycles = util.Set() + # get a set of dependent nodes of node and its cycles + nodealldeps = node.all_deps() + if nodealldeps: + # iterate over independent node indexes in reverse order so we can efficiently remove them + for index in xrange(len(independents)-1,-1,-1): + child, childsubtree, childcycles = independents[index] + # if there is a dependency between this node and an independent node + if (childsubtree.intersection(nodealldeps) or 
childcycles.intersection(node.dependencies)): + # prepend child to nodes children + # (append should be fine, but previous implemetation used prepend) + node.children[0:0] = [(child.item, [n.item for n in child.cycles or []], child.children)] + # merge childs subtree and cycles + subtree.update(childsubtree) + cycles.update(childcycles) + # remove the child from list of independent subtrees + independents[index:index+1] = [] + # add node as a new independent subtree + independents.append((node,subtree,cycles)) + # choose an arbitrary node from list of all independent subtrees + head = independents.pop()[0] + # add all other independent subtrees as a child of the chosen root + # used prepend [0:0] instead of extend to maintain exact behaviour of previous implementation + head.children[0:0] = [(i[0].item, [n.item for n in i[0].cycles or []], i[0].children) for i in independents] + return (head.item, [n.item for n in head.cycles or []], head.children) + +def _find_cycles(edges): + involved_in_cycles = util.Set() + cycles = {} + def traverse(node, goal=None, cycle=None): + if goal is None: + goal = node + cycle = [] + elif node is goal: + return True + + for (n, key) in edges.edges_by_parent(node): + if key in cycle: + continue + cycle.append(key) + if traverse(key, goal, cycle): + cycset = util.Set(cycle) + for x in cycle: + involved_in_cycles.add(x) + if x in cycles: + existing_set = cycles[x] + [existing_set.add(y) for y in cycset] + for y in existing_set: + cycles[y] = existing_set + cycset = existing_set + else: + cycles[x] = cycset + cycle.pop() + + for parent in edges.get_parents(): + traverse(parent) + + # sets are not hashable, so uniquify with id + unique_cycles = dict([(id(s), s) for s in cycles.values()]).values() + for cycle in unique_cycles: + edgecollection = [edge for edge in edges + if edge[0] in cycle and edge[1] in cycle] + yield edgecollection diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index ec14598520..e06ec9a5a5 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -1,52 +1,162 @@ # types.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -__all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine', - 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'FLOAT', 'DECIMAL', - 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN', 'String', 'Integer', 'SmallInteger','Smallinteger', - 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'PickleType', 'NULLTYPE', 'NullType', - 'SMALLINT', 'DATE', 'TIME','Interval' +"""defines genericized SQL types, each represented by a subclass of +[sqlalchemy.types#AbstractType]. Dialects define further subclasses of these +types. + +For more information see the SQLAlchemy documentation on types. 
+ +""" +__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', + 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT', + 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', + 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME', + 'String', 'Integer', 'SmallInteger','Smallinteger', + 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', + 'Boolean', 'Unicode', 'UnicodeText', 'PickleType', 'Interval', + 'type_map' ] import inspect import datetime as dt -from decimal import Decimal -try: - import cPickle as pickle -except: - import pickle from sqlalchemy import exceptions +from sqlalchemy.util import pickle, Decimal as _python_Decimal +import sqlalchemy.util as util +NoneType = type(None) + +class _UserTypeAdapter(type): + """adapts 0.3 style user-defined types with convert_bind_param/convert_result_value + to use newer bind_processor()/result_processor() methods.""" + + def __init__(cls, clsname, bases, dict): + if not hasattr(cls.convert_result_value, '_sa_override'): + cls.__instrument_result_proc(cls) + + if not hasattr(cls.convert_bind_param, '_sa_override'): + cls.__instrument_bind_proc(cls) + + return super(_UserTypeAdapter, cls).__init__(clsname, bases, dict) + + def __instrument_bind_proc(cls, class_): + def bind_processor(self, dialect): + def process(value): + return self.convert_bind_param(value, dialect) + return process + class_.super_bind_processor = class_.bind_processor + class_.bind_processor = bind_processor + + def __instrument_result_proc(cls, class_): + def result_processor(self, dialect): + def process(value): + return self.convert_result_value(value, dialect) + return process + class_.super_result_processor = class_.result_processor + class_.result_processor = result_processor + class AbstractType(object): + __metaclass__ = _UserTypeAdapter + def __init__(self, *args, **kwargs): pass - + def copy_value(self, value): return value + def convert_result_value(self, value, dialect): + """Legacy convert_result_value() compatibility method. + + This adapter method is provided for user-defined types that implement + the older convert_* interface and need to call their super method. + These calls are adapted behind the scenes to use the newer + callable-based interface via result_processor(). + + Compatibility is configured on a case-by-case basis at class + definition time by a legacy adapter metaclass. This method is only + available and functional if the concrete subclass implements the + legacy interface. + """ + + processor = self.super_result_processor(dialect) + if processor: + return processor(value) + else: + return value + convert_result_value._sa_override = True + + def convert_bind_param(self, value, dialect): + """Legacy convert_bind_param() compatability method. + + This adapter method is provided for user-defined types that implement + the older convert_* interface and need to call their super method. + These calls are adapted behind the scenes to use the newer + callable-based interface via bind_processor(). + + Compatibility is configured on a case-by-case basis at class + definition time by a legacy adapter metaclass. This method is only + available and functional if the concrete subclass implements the + legacy interface. 
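As a hedged illustration of the legacy adapter above, a hypothetical 0.3-style type
which defines only the old convert_* methods; the metaclass notices they are
overridden (no _sa_override marker) and re-routes bind_processor() /
result_processor() through them, while calls up to the superclass convert_* methods
are adapted back onto the new processor-based implementation::

    from sqlalchemy import types

    class PrefixedString(types.TypeDecorator):
        impl = types.String

        def convert_bind_param(self, value, dialect):
            # old-style hook, still honored via _UserTypeAdapter
            if value is not None:
                value = "pfx:" + value
            return super(PrefixedString, self).convert_bind_param(value, dialect)

        def convert_result_value(self, value, dialect):
            value = super(PrefixedString, self).convert_result_value(value, dialect)
            if value is not None:
                value = value[4:]
            return value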
+ """ + + processor = self.super_bind_processor(dialect) + if processor: + return processor(value) + else: + return value + convert_bind_param._sa_override = True + + def bind_processor(self, dialect): + """Defines a bind parameter processing function.""" + + return None + + def result_processor(self, dialect): + """Defines a result-column processing function.""" + + return None + def compare_values(self, x, y): + """compare two values for equality.""" + return x == y def is_mutable(self): + """return True if the target Python type is 'mutable'. + + This allows systems like the ORM to know if an object + can be considered 'not changed' by identity alone. + """ + return False def get_dbapi_type(self, dbapi): - """Return the corresponding type object from the underlying DBAPI, if any. + """Return the corresponding type object from the underlying DB-API, if any. This can be useful for calling ``setinputsizes()``, for example. """ return None + def adapt_operator(self, op): + """given an operator from the sqlalchemy.sql.operators package, + translate it to a new operator based on the semantics of this type. + + By default, returns the operator unchanged.""" + return op + def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]])) + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join(["%s=%r" % (k, getattr(self, k, None)) + for k in inspect.getargspec(self.__init__)[0][1:]])) class TypeEngine(AbstractType): - def dialect_impl(self, dialect): + def dialect_impl(self, dialect, **kwargs): try: return self._impl_dict[dialect] except AttributeError: @@ -54,39 +164,60 @@ class TypeEngine(AbstractType): return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self)) except KeyError: return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self)) - + def __getstate__(self): d = self.__dict__.copy() d['_impl_dict'] = {} return d - + def get_col_spec(self): raise NotImplementedError() - def convert_bind_param(self, value, dialect): - return value - def convert_result_value(self, value, dialect): - return value + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + return None def adapt(self, cls): return cls() - + def get_search_list(self): - """return a list of classes to test for a match + """return a list of classes to test for a match when adapting this type to a dialect-specific type. - + """ - + return self.__class__.__mro__[0:-1] - + class TypeDecorator(AbstractType): + """Allows the creation of types which add additional functionality + to an existing type. Typical usage:: + + class MyCustomType(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + return value + "incoming string" + + def process_result_value(self, value, dialect): + return value[0:-16] + + The class-level "impl" variable is required, and can reference any + TypeEngine class. Alternatively, the load_dialect_impl() method can + be used to provide different type classes based on the dialect given; + in this case, the "impl" variable can reference ``TypeEngine`` as a + placeholder. 
+ + """ + def __init__(self, *args, **kwargs): if not hasattr(self.__class__, 'impl'): raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated") self.impl = self.__class__.impl(*args, **kwargs) - def dialect_impl(self, dialect): + def dialect_impl(self, dialect, **kwargs): try: return self._impl_dict[dialect] except AttributeError: @@ -94,7 +225,10 @@ class TypeDecorator(AbstractType): except KeyError: pass - typedesc = self.load_dialect_impl(dialect) + if isinstance(self.impl, TypeDecorator): + typedesc = self.impl.dialect_impl(dialect) + else: + typedesc = self.load_dialect_impl(dialect) tt = self.copy() if not isinstance(tt, self.__class__): raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__)) @@ -104,13 +238,13 @@ class TypeDecorator(AbstractType): def load_dialect_impl(self, dialect): """loads the dialect-specific implementation of this type. - + by default calls dialect.type_descriptor(self.impl), but can be overridden to provide different behavior. """ return dialect.type_descriptor(self.impl) - + def __getattr__(self, key): """Proxy all other undefined accessors to the underlying implementation.""" @@ -119,11 +253,39 @@ class TypeDecorator(AbstractType): def get_col_spec(self): return self.impl.get_col_spec() - def convert_bind_param(self, value, dialect): - return self.impl.convert_bind_param(value, dialect) + def process_bind_param(self, value, dialect): + raise NotImplementedError() - def convert_result_value(self, value, dialect): - return self.impl.convert_result_value(value, dialect) + def process_result_value(self, value, dialect): + raise NotImplementedError() + + def bind_processor(self, dialect): + if self.__class__.process_bind_param.func_code is not TypeDecorator.process_bind_param.func_code: + impl_processor = self.impl.bind_processor(dialect) + if impl_processor: + def process(value): + return impl_processor(self.process_bind_param(value, dialect)) + return process + else: + def process(value): + return self.process_bind_param(value, dialect) + return process + else: + return self.impl.bind_processor(dialect) + + def result_processor(self, dialect): + if self.__class__.process_result_value.func_code is not TypeDecorator.process_result_value.func_code: + impl_processor = self.impl.result_processor(dialect) + if impl_processor: + def process(value): + return self.process_result_value(impl_processor(value), dialect) + return process + else: + def process(value): + return self.process_result_value(value, dialect) + return process + else: + return self.impl.result_processor(dialect) def copy(self): instance = self.__class__.__new__(self.__class__) @@ -187,52 +349,111 @@ class NullType(TypeEngine): def get_col_spec(self): raise NotImplementedError() - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - return value NullTypeEngine = NullType class Concatenable(object): """marks a type as supporting 'concatenation'""" - pass - -class String(TypeEngine, Concatenable): - def __init__(self, length=None, convert_unicode=False): + def adapt_operator(self, op): + from sqlalchemy.sql import operators + if op == operators.add: + return operators.concat_op + else: + return op + +class String(Concatenable, TypeEngine): + """A sized string type. + + Usually corresponds to VARCHAR. 
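A short sketch of the Concatenable hook just above: types which mark themselves as
concatenable translate the Python "+" operator into SQL concatenation::

    from sqlalchemy import types
    from sqlalchemy.sql import operators

    types.String().adapt_operator(operators.add) is operators.concat_op   # True
    types.Integer().adapt_operator(operators.add) is operators.add        # True (unchanged)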
Can also take Python unicode objects + and encode to the database's encoding in bind params (and the reverse for + result sets.) + + a String with no length will adapt itself automatically to a Text + object at the dialect level (this behavior is deprecated in 0.4). + """ + def __init__(self, length=None, convert_unicode=False, assert_unicode=None): self.length = length self.convert_unicode = convert_unicode + self.assert_unicode = assert_unicode def adapt(self, impltype): - return impltype(length=self.length, convert_unicode=self.convert_unicode) + return impltype(length=self.length, convert_unicode=self.convert_unicode, assert_unicode=self.assert_unicode) + + def bind_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + if self.assert_unicode is None: + assert_unicode = dialect.assert_unicode + else: + assert_unicode = self.assert_unicode + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + elif assert_unicode and not isinstance(value, (unicode, NoneType)): + if assert_unicode == 'warn': + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + else: + raise exceptions.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) + else: + return value + return process + else: + return None - def convert_bind_param(self, value, dialect): - if not (self.convert_unicode or dialect.convert_unicode) or value is None or not isinstance(value, unicode): - return value + def result_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + def process(value): + if value is not None and not isinstance(value, unicode): + return value.decode(dialect.encoding) + else: + return value + return process else: - return value.encode(dialect.encoding) + return None + + def dialect_impl(self, dialect, **kwargs): + _for_ddl = kwargs.pop('_for_ddl', False) + if _for_ddl and self.length is None: + label = util.to_ascii(_for_ddl is True and + '' or (' for column "%s"' % str(_for_ddl))) + util.warn_deprecated( + "Using String type with no length for CREATE TABLE " + "is deprecated; use the Text type explicitly" + label) + return TypeEngine.dialect_impl(self, dialect, **kwargs) def get_search_list(self): l = super(String, self).get_search_list() - if self.length is None: - return (TEXT,) + l + # if we are String or Unicode with no length, + # return Text as the highest-priority type + # to be adapted by the dialect + if self.length is None and l[0] in (String, Unicode): + return (Text,) + l else: return l - def convert_result_value(self, value, dialect): - if not (self.convert_unicode or dialect.convert_unicode) or value is None or isinstance(value, unicode): - return value - else: - return value.decode(dialect.encoding) - def get_dbapi_type(self, dbapi): return dbapi.STRING +class Text(String): + def dialect_impl(self, dialect, **kwargs): + return TypeEngine.dialect_impl(self, dialect, **kwargs) + class Unicode(String): + """A synonym for String(length, convert_unicode=True, assert_unicode='warn').""" + def __init__(self, length=None, **kwargs): kwargs['convert_unicode'] = True + kwargs['assert_unicode'] = 'warn' super(Unicode, self).__init__(length=length, **kwargs) - + +class UnicodeText(Text): + """A synonym for Text(convert_unicode=True, assert_unicode='warn').""" + + def __init__(self, length=None, **kwargs): + kwargs['convert_unicode'] = True + kwargs['assert_unicode'] = 'warn' + super(UnicodeText, self).__init__(length=length, **kwargs) + class 
Integer(TypeEngine): """Integer datatype.""" @@ -242,12 +463,12 @@ class Integer(TypeEngine): class SmallInteger(Integer): """Smallint datatype.""" - pass - Smallinteger = SmallInteger class Numeric(TypeEngine): - def __init__(self, precision = 10, length = 2, asdecimal=True): + """Numeric datatype, usually resolves to DECIMAL or NUMERIC.""" + + def __init__(self, precision=10, length=2, asdecimal=True): self.precision = precision self.length = length self.asdecimal = asdecimal @@ -258,23 +479,30 @@ class Numeric(TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.NUMBER - def convert_bind_param(self, value, dialect): - if value is not None: - return float(value) + def bind_processor(self, dialect): + def process(value): + if value is not None: + return float(value) + else: + return value + return process + + def result_processor(self, dialect): + if self.asdecimal: + def process(value): + if value is not None: + return _python_Decimal(str(value)) + else: + return value + return process else: - return value - - def convert_result_value(self, value, dialect): - if value is not None and self.asdecimal: - return Decimal(str(value)) - else: - return value + return None class Float(Numeric): def __init__(self, precision = 10, asdecimal=False, **kwargs): self.precision = precision self.asdecimal = asdecimal - + def adapt(self, impltype): return impltype(precision=self.precision, asdecimal=self.asdecimal) @@ -312,14 +540,14 @@ class Binary(TypeEngine): def __init__(self, length=None): self.length = length - def convert_bind_param(self, value, dialect): - if value is not None: - return dialect.dbapi.Binary(value) - else: - return None - - def convert_result_value(self, value, dialect): - return value + def bind_processor(self, dialect): + DBAPIBinary = dialect.dbapi.Binary + def process(value): + if value is not None: + return DBAPIBinary(value) + else: + return None + return process def adapt(self, impltype): return impltype(length=self.length) @@ -330,22 +558,25 @@ class Binary(TypeEngine): class PickleType(MutableType, TypeDecorator): impl = Binary - def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, mutable=True): + def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, mutable=True, comparator=None): self.protocol = protocol self.pickler = pickler or pickle self.mutable = mutable + self.comparator = comparator super(PickleType, self).__init__() - def convert_result_value(self, value, dialect): + def process_bind_param(self, value, dialect): + dumps = self.pickler.dumps + protocol = self.protocol if value is None: return None - buf = self.impl.convert_result_value(value, dialect) - return self.pickler.loads(str(buf)) + return dumps(value, protocol) - def convert_bind_param(self, value, dialect): + def process_result_value(self, value, dialect): + loads = self.pickler.loads if value is None: return None - return self.impl.convert_bind_param(self.pickler.dumps(value, self.protocol), dialect) + return loads(str(value)) def copy_value(self, value): if self.mutable: @@ -354,17 +585,19 @@ class PickleType(MutableType, TypeDecorator): return value def compare_values(self, x, y): - if self.mutable: + if self.comparator: + return self.comparator(x, y) + elif self.mutable: return self.pickler.dumps(x, self.protocol) == self.pickler.dumps(y, self.protocol) else: - return x is y + return x == y def is_mutable(self): return self.mutable class Boolean(TypeEngine): pass - + class Interval(TypeDecorator): """Type to be used in Column statements to store python timedeltas. 
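
A sketch of the new PickleType "comparator" argument shown above; the table and column names are
made up. The callable is consulted by compare_values() during flush, instead of comparing pickled
byte strings (mutable) or plain equality (non-mutable):

    from sqlalchemy import MetaData, Table, Column, Integer, PickleType

    metadata = MetaData()
    jobs = Table('jobs', metadata,
        Column('id', Integer, primary_key=True),
        Column('payload', PickleType(mutable=True,
                                     comparator=lambda x, y: x == y)))
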
@@ -374,66 +607,70 @@ class Interval(TypeDecorator): Converting is very simple - just use epoch(zero timestamp, 01.01.1970) as base, so if we need to store timedelta = 1 day (24 hours) in database it - will be stored as DateTime = '2nd Jan 1970 00:00', see convert_bind_param - and convert_result_value to actual conversion code + will be stored as DateTime = '2nd Jan 1970 00:00', see bind_processor + and result_processor to actual conversion code """ - #Empty useless type, because at the moment of creation of instance we don't - #know what type will be decorated - it depends on used dialect. + impl = TypeEngine + def __init__(self): + super(Interval, self).__init__() + import sqlalchemy.databases.postgres as pg + self.__supported = {pg.PGDialect:pg.PGInterval} + del pg + def load_dialect_impl(self, dialect): - """Checks if engine has native implementation of timedelta python type, - if so it returns right class to handle it, if there is no native support, - it fallback to engine's DateTime implementation class - """ - if not hasattr(self,'__supported'): - import sqlalchemy.databases.postgres as pg - self.__supported = {pg.PGDialect:pg.PGInterval} - del pg - - if self.__hasNativeImpl(dialect): - #For now, only PostgreSQL has native timedelta types support + if dialect.__class__ in self.__supported: return self.__supported[dialect.__class__]() else: - #All others should fallback to DateTime return dialect.type_descriptor(DateTime) - - def __hasNativeImpl(self,dialect): - return dialect.__class__ in self.__supported - - def convert_bind_param(self, value, dialect): - if value is None: - return None - if not self.__hasNativeImpl(dialect): - tmpval = dt.datetime.utcfromtimestamp(0) + value - return self.impl.convert_bind_param(tmpval,dialect) - else: - return self.impl.convert_bind_param(value,dialect) - def convert_result_value(self, value, dialect): - if value is None: - return None - retval = self.impl.convert_result_value(value,dialect) - if not self.__hasNativeImpl(dialect): - return retval - dt.datetime.utcfromtimestamp(0) + def process_bind_param(self, value, dialect): + if dialect.__class__ in self.__supported: + return value else: - return retval - -class FLOAT(Float):pass -class TEXT(String):pass -class DECIMAL(Numeric):pass -class INT(Integer):pass + if value is None: + return None + return dt.datetime.utcfromtimestamp(0) + value + + def process_result_value(self, value, dialect): + if dialect.__class__ in self.__supported: + return value + else: + if value is None: + return None + return value - dt.datetime.utcfromtimestamp(0) + +class FLOAT(Float): pass +TEXT = Text +class NUMERIC(Numeric): pass +class DECIMAL(Numeric): pass +class INT(Integer): pass INTEGER = INT -class SMALLINT(Smallinteger):pass +class SMALLINT(Smallinteger): pass class TIMESTAMP(DateTime): pass class DATETIME(DateTime): pass class DATE(Date): pass class TIME(Time): pass -class CLOB(TEXT): pass +class CLOB(Text): pass class VARCHAR(String): pass -class CHAR(String):pass -class NCHAR(Unicode):pass +class CHAR(String): pass +class NCHAR(Unicode): pass class BLOB(Binary): pass class BOOLEAN(Boolean): pass NULLTYPE = NullType() + +# using VARCHAR/NCHAR so that we dont get the genericized "String" +# type which usually resolves to TEXT/CLOB +type_map = { + str : VARCHAR, + unicode : NCHAR, + int : Integer, + float : Numeric, + dt.date : Date, + dt.datetime : DateTime, + dt.time : Time, + dt.timedelta : Interval, + type(None): NullType +} diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 
e711de3a3e..e88c4b3b9b 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -1,30 +1,77 @@ # util.py -# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import inspect, itertools, new, operator, sets, sys, warnings, weakref +import __builtin__ +types = __import__('types') + +from sqlalchemy import exceptions + try: import thread, threading except ImportError: import dummy_thread as thread import dummy_threading as threading -from sqlalchemy import exceptions -import md5 -import sys -import warnings -import __builtin__ - try: Set = set -except: - import sets - Set = sets.Set + set_types = set, sets.Set +except NameError: + set_types = sets.Set, + # layer some of __builtin__.set's binop behavior onto sets.Set + class Set(sets.Set): + def _binary_sanity_check(self, other): + pass + + def issubset(self, iterable): + other = type(self)(iterable) + return sets.Set.issubset(self, other) + def __le__(self, other): + sets.Set._binary_sanity_check(self, other) + return sets.Set.__le__(self, other) + def issuperset(self, iterable): + other = type(self)(iterable) + return sets.Set.issuperset(self, other) + def __ge__(self, other): + sets.Set._binary_sanity_check(self, other) + return sets.Set.__ge__(self, other) + + # lt and gt still require a BaseSet + def __lt__(self, other): + sets.Set._binary_sanity_check(self, other) + return sets.Set.__lt__(self, other) + def __gt__(self, other): + sets.Set._binary_sanity_check(self, other) + return sets.Set.__gt__(self, other) + + def __ior__(self, other): + if not isinstance(other, sets.BaseSet): + return NotImplemented + return sets.Set.__ior__(self, other) + def __iand__(self, other): + if not isinstance(other, sets.BaseSet): + return NotImplemented + return sets.Set.__iand__(self, other) + def __ixor__(self, other): + if not isinstance(other, sets.BaseSet): + return NotImplemented + return sets.Set.__ixor__(self, other) + def __isub__(self, other): + if not isinstance(other, sets.BaseSet): + return NotImplemented + return sets.Set.__isub__(self, other) + +try: + import cPickle as pickle +except ImportError: + import pickle try: reversed = __builtin__.reversed -except: +except AttributeError: def reversed(seq): i = len(seq) -1 while i >= 0: @@ -32,13 +79,34 @@ except: i -= 1 raise StopIteration() +try: + # Try the standard decimal for > 2.3 or the compatibility module + # for 2.3, if installed. + from decimal import Decimal + decimal_type = Decimal +except ImportError: + def Decimal(arg): + if Decimal.warn: + warn("True Decimal types not available on this Python, " + "falling back to floats.") + Decimal.warn = False + return float(arg) + Decimal.warn = True + decimal_type = float + +try: + from operator import attrgetter +except: + def attrgetter(attribute): + return lambda value: getattr(value, attribute) + if sys.version_info >= (2, 5): class PopulateDict(dict): """a dict which populates missing values via a creation function. - + note the creation function takes a key, unlike collections.defaultdict. 
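
A usage sketch for PopulateDict as defined here; the creation function receives the missing key
itself, which is what distinguishes it from collections.defaultdict:

    from sqlalchemy import util

    squares = util.PopulateDict(lambda key: key * key)
    assert squares[3] == 9    # computed via the creator on first access
    assert 3 in squares       # and stored for subsequent lookups
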
""" - + def __init__(self, creator): self.creator = creator def __missing__(self, key): @@ -57,6 +125,61 @@ else: self[key] = value = self.creator(key) return value +try: + from collections import defaultdict +except ImportError: + class defaultdict(dict): + def __init__(self, default_factory=None, *a, **kw): + if (default_factory is not None and + not hasattr(default_factory, '__call__')): + raise TypeError('first argument must be callable') + dict.__init__(self, *a, **kw) + self.default_factory = default_factory + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + return self.__missing__(key) + def __missing__(self, key): + if self.default_factory is None: + raise KeyError(key) + self[key] = value = self.default_factory() + return value + def __reduce__(self): + if self.default_factory is None: + args = tuple() + else: + args = self.default_factory, + return type(self), args, None, None, self.iteritems() + def copy(self): + return self.__copy__() + def __copy__(self): + return type(self)(self.default_factory, self) + def __deepcopy__(self, memo): + import copy + return type(self)(self.default_factory, + copy.deepcopy(self.items())) + def __repr__(self): + return 'defaultdict(%s, %s)' % (self.default_factory, + dict.__repr__(self)) + +try: + from collections import deque +except ImportError: + class deque(list): + def appendleft(self, x): + self.insert(0, x) + + def extendleft(self, iterable): + self[0:0] = list(iterable) + + def popleft(self): + return self.pop(0) + + def rotate(self, n): + for i in xrange(n): + self.appendleft(self.pop()) + def to_list(x, default=None): if x is None: return default @@ -65,6 +188,18 @@ def to_list(x, default=None): else: return x +def array_as_starargs_decorator(func): + """Interpret a single positional array argument as + *args for the decorated method. + + """ + def starargs_as_list(self, *args, **kwargs): + if len(args) == 1: + return func(self, *to_list(args[0], []), **kwargs) + else: + return func(self, *args, **kwargs) + return starargs_as_list + def to_set(x): if x is None: return Set() @@ -73,6 +208,16 @@ def to_set(x): else: return x +def to_ascii(x): + """Convert Unicode or a string with unknown encoding into ASCII.""" + + if isinstance(x, str): + return x.encode('string_escape') + elif isinstance(x, unicode): + return x.encode('unicode_escape') + else: + raise TypeError + def flatten_iterator(x): """Given an iterator of which further sub-elements may also be iterators, flatten the sub-elements into a single iterator. @@ -85,20 +230,14 @@ def flatten_iterator(x): else: yield elem -def hash(string): - """return an md5 hash of the given string.""" - h = md5.new() - h.update(string) - return h.hexdigest() - - class ArgSingleton(type): - instances = {} + instances = weakref.WeakValueDictionary() - def dispose_static(self, *args): - hashkey = (self, args) - #if hashkey in ArgSingleton.instances: - del ArgSingleton.instances[hashkey] + def dispose(cls): + for key in list(ArgSingleton.instances): + if key[0] is cls: + del ArgSingleton.instances[key] + dispose = staticmethod(dispose) def __call__(self, *args): hashkey = (self, args) @@ -110,20 +249,57 @@ class ArgSingleton(type): return instance def get_cls_kwargs(cls): - """Return the full set of legal kwargs for the given `cls`.""" + """Return the full set of inherited kwargs for the given `cls`. + + Probes a class's __init__ method, collecting all named arguments. 
If the + __init__ defines a **kwargs catch-all, then the constructor is presumed to + pass along unrecognized keywords to it's base classes, and the collection + process is repeated recursively on each of the bases. + """ - kw = [] for c in cls.__mro__: - cons = c.__init__ - if hasattr(cons, 'func_code'): - for vn in cons.func_code.co_varnames: - if vn != 'self': - kw.append(vn) - return kw + if '__init__' in c.__dict__: + stack = Set([c]) + break + else: + return [] + + args = Set() + while stack: + class_ = stack.pop() + ctr = class_.__dict__.get('__init__', False) + if not ctr or not isinstance(ctr, types.FunctionType): + continue + names, _, has_kw, _ = inspect.getargspec(ctr) + args.update(names) + if has_kw: + stack.update(class_.__bases__) + args.discard('self') + return list(args) def get_func_kwargs(func): """Return the full set of legal kwargs for the given `func`.""" - return [vn for vn in func.func_code.co_varnames] + return inspect.getargspec(func)[0] + +def unbound_method_to_callable(func_or_cls): + """Adjust the incoming callable such that a 'self' argument is not required.""" + + if isinstance(func_or_cls, types.MethodType) and not func_or_cls.im_self: + return func_or_cls.im_func + else: + return func_or_cls + +# from paste.deploy.converters +def asbool(obj): + if isinstance(obj, (str, unicode)): + obj = obj.strip().lower() + if obj in ['true', 'yes', 'on', 'y', 't', '1']: + return True + elif obj in ['false', 'no', 'off', 'n', 'f', '0']: + return False + else: + raise ValueError("String is not true/false: %r" % obj) + return bool(obj) def coerce_kw_type(kw, key, type_, flexi_bool=True): """If 'key' is present in dict 'kw', coerce its value to type 'type_' if @@ -132,8 +308,8 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): """ if key in kw and type(kw[key]) is not type_ and kw[key] is not None: - if type_ is bool and flexi_bool and kw[key] == '0': - kw[key] = False + if type_ is bool and flexi_bool: + kw[key] = asbool(kw[key]) else: kw[key] = type_(kw[key]) @@ -142,13 +318,18 @@ def duck_type_collection(specimen, default=None): the basic collection types: list, set and dict. If the __emulates__ property is present, return that preferentially. 
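
A sketch of the inheritance-aware collection performed by get_cls_kwargs(); the classes are made
up. Because the subclass constructor declares **kw, the base class's arguments are collected too:

    from sqlalchemy import util

    class Base(object):
        def __init__(self, echo=False):
            self.echo = echo

    class Derived(Base):
        def __init__(self, pool_size=5, **kw):
            Base.__init__(self, **kw)
            self.pool_size = pool_size

    assert sorted(util.get_cls_kwargs(Derived)) == ['echo', 'pool_size']
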
""" - + if hasattr(specimen, '__emulates__'): - return specimen.__emulates__ + # canonicalize set vs sets.Set to a standard: util.Set + if (specimen.__emulates__ is not None and + issubclass(specimen.__emulates__, set_types)): + return Set + else: + return specimen.__emulates__ isa = isinstance(specimen, type) and issubclass or isinstance if isa(specimen, list): return list - if isa(specimen, Set): return Set + if isa(specimen, set_types): return Set if isa(specimen, dict): return dict if hasattr(specimen, 'append'): @@ -160,6 +341,30 @@ def duck_type_collection(specimen, default=None): else: return default +def dictlike_iteritems(dictlike): + """Return a (key, value) iterator for almost any dict-like object.""" + + if hasattr(dictlike, 'iteritems'): + return dictlike.iteritems() + elif hasattr(dictlike, 'items'): + return iter(dictlike.items()) + + getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None)) + if getter is None: + raise TypeError( + "Object '%r' is not dict-like" % dictlike) + + if hasattr(dictlike, 'iterkeys'): + def iterator(): + for key in dictlike.iterkeys(): + yield key, getter(key) + return iterator() + elif hasattr(dictlike, 'keys'): + return iter([(key, getter(key)) for key in dictlike.keys()]) + else: + raise TypeError( + "Object '%r' is not dict-like" % dictlike) + def assert_arg_type(arg, argtype, name): if isinstance(arg, argtype): return arg @@ -174,8 +379,37 @@ def warn_exception(func, *args, **kwargs): try: return func(*args, **kwargs) except: - warnings.warn(RuntimeWarning("%s('%s') ignored" % sys.exc_info()[0:2])) - + warn("%s('%s') ignored" % sys.exc_info()[0:2]) + +def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None, + name='self.proxy', from_instance=None): + """Automates delegation of __specials__ for a proxying type.""" + + if only: + dunders = only + else: + if skip is None: + skip = ('__slots__', '__del__', '__getattribute__', + '__metaclass__', '__getstate__', '__setstate__') + dunders = [m for m in dir(from_cls) + if (m.startswith('__') and m.endswith('__') and + not hasattr(into_cls, m) and m not in skip)] + for method in dunders: + try: + spec = inspect.getargspec(getattr(from_cls, method)) + fn_args = inspect.formatargspec(spec[0]) + d_args = inspect.formatargspec(spec[0][1:]) + except TypeError: + fn_args = '(self, *args, **kw)' + d_args = '(*args, **kw)' + + py = ("def %(method)s%(fn_args)s: " + "return %(name)s.%(method)s%(d_args)s" % locals()) + + env = from_instance is not None and {name: from_instance} or {} + exec py in env + setattr(into_cls, method, env[method]) + class SimpleProperty(object): """A *default* property accessor.""" @@ -194,12 +428,13 @@ class SimpleProperty(object): else: return getattr(obj, self.key) + class NotImplProperty(object): """a property that raises ``NotImplementedError``.""" - + def __init__(self, doc): self.__doc__ = doc - + def __set__(self, obj, value): raise NotImplementedError() @@ -211,7 +446,7 @@ class NotImplProperty(object): return self else: raise NotImplementedError() - + class OrderedProperties(object): """An object that maintains the order in which attributes are set upon it. 
@@ -247,7 +482,11 @@ class OrderedProperties(object): def __setattr__(self, key, object): self._data[key] = object - _data = property(lambda s:s.__dict__['_data']) + def __getstate__(self): + return {'_data': self.__dict__['_data']} + + def __setstate__(self, state): + self.__dict__['_data'] = state['_data'] def __getattr__(self, key): try: @@ -257,9 +496,12 @@ class OrderedProperties(object): def __contains__(self, key): return key in self._data - + + def update(self, value): + self._data.update(value) + def get(self, key, default=None): - if self.has_key(key): + if key in self: return self[key] else: return default @@ -279,7 +521,8 @@ class OrderedDict(dict): def __init__(self, ____sequence=None, **kwargs): self._list = [] if ____sequence is None: - self.update(**kwargs) + if kwargs: + self.update(**kwargs) else: self.update(____sequence, **kwargs) @@ -299,7 +542,7 @@ class OrderedDict(dict): self.update(kwargs) def setdefault(self, key, value): - if not self.has_key(key): + if key not in self: self.__setitem__(key, value) return value else: @@ -327,7 +570,7 @@ class OrderedDict(dict): return iter(self.items()) def __setitem__(self, key, object): - if not self.has_key(key): + if key not in self: self._list.append(key) dict.__setitem__(self, key, object) @@ -345,73 +588,60 @@ class OrderedDict(dict): self._list.remove(item[0]) return item -class ThreadLocal(object): - """An object in which attribute access occurs only within the context of the current thread.""" - - def __init__(self): - self.__dict__['_tdict'] = {} - - def __delattr__(self, key): - try: - del self._tdict["%d_%s" % (thread.get_ident(), key)] - except KeyError: - raise AttributeError(key) - - def __getattr__(self, key): - try: - return self._tdict["%d_%s" % (thread.get_ident(), key)] - except KeyError: - raise AttributeError(key) +try: + from threading import local as ThreadLocal +except ImportError: + try: + from dummy_threading import local as ThreadLocal + except ImportError: + class ThreadLocal(object): + """An object in which attribute access occurs only within the context of the current thread.""" - def __setattr__(self, key, value): - self._tdict["%d_%s" % (thread.get_ident(), key)] = value + def __init__(self): + self.__dict__['_tdict'] = {} -class DictDecorator(dict): - """A Dictionary that delegates items not found to a second wrapped dictionary.""" + def __delattr__(self, key): + try: + del self._tdict[(thread.get_ident(), key)] + except KeyError: + raise AttributeError(key) - def __init__(self, decorate): - self.decorate = decorate + def __getattr__(self, key): + try: + return self._tdict[(thread.get_ident(), key)] + except KeyError: + raise AttributeError(key) - def __getitem__(self, key): - try: - return dict.__getitem__(self, key) - except KeyError: - return self.decorate[key] - - def __contains__(self, key): - return dict.__contains__(self, key) or key in self.decorate - - def has_key(self, key): - return key in self - - def __repr__(self): - return dict.__repr__(self) + repr(self.decorate) + def __setattr__(self, key, value): + self._tdict[(thread.get_ident(), key)] = value class OrderedSet(Set): - def __init__(self, d=None, **kwargs): - super(OrderedSet, self).__init__(**kwargs) - self._list = [] - if d: self.update(d, **kwargs) + def __init__(self, d=None): + Set.__init__(self) + self._list = [] + if d is not None: + self.update(d) def add(self, key): - if key not in self: - self._list.append(key) - super(OrderedSet, self).add(key) + if key not in self: + self._list.append(key) + Set.add(self, key) def 
remove(self, element): - super(OrderedSet, self).remove(element) - self._list.remove(element) + Set.remove(self, element) + self._list.remove(element) def discard(self, element): - try: - super(OrderedSet, self).remove(element) - except KeyError: pass - else: - self._list.remove(element) + try: + Set.remove(self, element) + except KeyError: + pass + else: + self._list.remove(element) def clear(self): - super(OrderedSet, self).clear() - self._list=[] + Set.clear(self) + self._list = [] def __getitem__(self, key): return self._list[key] @@ -419,16 +649,19 @@ class OrderedSet(Set): def __iter__(self): return iter(self._list) - def update(self, iterable): - add = self.add - for i in iterable: add(i) - return self - def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self._list) __str__ = __repr__ + def update(self, iterable): + add = self.add + for i in iterable: + add(i) + return self + + __ior__ = update + def union(self, other): result = self.__class__(self) result.update(other) @@ -437,33 +670,35 @@ class OrderedSet(Set): __or__ = union def intersection(self, other): - return self.__class__([a for a in self if a in other]) + other = Set(other) + return self.__class__([a for a in self if a in other]) __and__ = intersection def symmetric_difference(self, other): - result = self.__class__([a for a in self if a not in other]) - result.update([a for a in other if a not in self]) - return result + other = Set(other) + result = self.__class__([a for a in self if a not in other]) + result.update([a for a in other if a not in self]) + return result __xor__ = symmetric_difference def difference(self, other): - return self.__class__([a for a in self if a not in other]) + other = Set(other) + return self.__class__([a for a in self if a not in other]) __sub__ = difference - __ior__ = update - def intersection_update(self, other): - super(OrderedSet, self).intersection_update(other) - self._list = [ a for a in self._list if a in other] - return self + other = Set(other) + Set.intersection_update(self, other) + self._list = [ a for a in self._list if a in other] + return self __iand__ = intersection_update def symmetric_difference_update(self, other): - super(OrderedSet, self).symmetric_difference_update(other) + Set.symmetric_difference_update(self, other) self._list = [ a for a in self._list if a in self] self._list += [ a for a in other._list if a in self] return self @@ -471,19 +706,262 @@ class OrderedSet(Set): __ixor__ = symmetric_difference_update def difference_update(self, other): - super(OrderedSet, self).difference_update(other) - self._list = [ a for a in self._list if a in self] - return self + Set.difference_update(self, other) + self._list = [ a for a in self._list if a in self] + return self __isub__ = difference_update + if hasattr(Set, '__getstate__'): + def __getstate__(self): + base = Set.__getstate__(self) + return base, self._list + + def __setstate__(self, state): + Set.__setstate__(self, state[0]) + self._list = state[1] + +class IdentitySet(object): + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. 
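
A small sketch of the identity-only semantics described above; the Item class is made up and
compares equal to everything, yet both instances remain distinct members:

    from sqlalchemy import util

    class Item(object):
        def __eq__(self, other):
            return True               # everything compares equal

    x, y = Item(), Item()
    ids = util.IdentitySet([x, y])
    assert len(ids) == 2              # membership is by id(), not equality
    ids.add(x)                        # re-adding the same object changes nothing
    assert len(ids) == 2 and x in ids and y in ids
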
+ """ + + _working_set = Set + + def __init__(self, iterable=None): + self._members = _IterableUpdatableDict() + if iterable: + for o in iterable: + self.add(o) + + def add(self, value): + self._members[id(value)] = value + + def __contains__(self, value): + return id(value) in self._members + + def remove(self, value): + del self._members[id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError('pop from an empty set') + + def clear(self): + self._members.clear() + + def __cmp__(self, other): + raise TypeError('cannot compare sets using cmp()') + + def __eq__(self, other): + if isinstance(other, IdentitySet): + return self._members == other._members + else: + return False + + def __ne__(self, other): + if isinstance(other, IdentitySet): + return self._members != other._members + else: + return True + + def issubset(self, iterable): + other = type(self)(iterable) + + if len(self) > len(other): + return False + for m in itertools.ifilterfalse(other._members.has_key, + self._members.iterkeys()): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + def issuperset(self, iterable): + other = type(self)(iterable) + + if len(self) < len(other): + return False + + for m in itertools.ifilterfalse(self._members.has_key, + other._members.iterkeys()): + return False + return True + + def __ge__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + def union(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._members.iteritems()).union(_iter_id(iterable))) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.union(other) + + def update(self, iterable): + self._members = self.union(iterable)._members + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + def difference(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._members.iteritems()).difference(_iter_id(iterable))) + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.difference(other) + + def difference_update(self, iterable): + self._members = self.difference(iterable)._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + def intersection(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._members.iteritems()).intersection(_iter_id(iterable))) + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.intersection(other) + + def intersection_update(self, iterable): + self._members = self.intersection(iterable)._members + + def 
__iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + def symmetric_difference(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._members.iteritems()).symmetric_difference(_iter_id(iterable))) + return result + + def __xor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference_update(self, iterable): + self._members = self.symmetric_difference(iterable)._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + def copy(self): + return type(self)(self._members.itervalues()) + + __copy__ = copy + + def __len__(self): + return len(self._members) + + def __iter__(self): + return self._members.itervalues() + + def __hash__(self): + raise TypeError('set objects are unhashable') + + def __repr__(self): + return '%s(%r)' % (type(self).__name__, self._members.values()) + +if sys.version_info >= (2, 4): + _IterableUpdatableDict = dict +else: + class _IterableUpdatableDict(dict): + """A dict that can update(iterable) like Python 2.4+'s dict.""" + def update(self, __iterable=None, **kw): + if __iterable is not None: + if not isinstance(__iterable, dict): + __iterable = dict(__iterable) + dict.update(self, __iterable) + if kw: + dict.update(self, **kw) + +def _iter_id(iterable): + """Generator: ((id(o), o) for o in iterable).""" + for item in iterable: + yield id(item), item + + +class OrderedIdentitySet(IdentitySet): + class _working_set(OrderedSet): + # a testing pragma: exempt the OIDS working set from the test suite's + # "never call the user's __hash__" assertions. this is a big hammer, + # but it's safe here: IDS operates on (id, instance) tuples in the + # working set. + __sa_hash_exempt__ = True + + def __init__(self, iterable=None): + IdentitySet.__init__(self) + self._members = OrderedDict() + if iterable: + for o in iterable: + self.add(o) + class UniqueAppender(object): - """appends items to a collection such that only unique items - are added.""" - + """Only adds items to a collection once. + + Additional appends() of the same object are ignored. Membership is + determined by identity (``is a``) not equality (``==``). + """ + def __init__(self, data, via=None): self.data = data - self._unique = Set() + self._unique = IdentitySet() if via: self._data_appender = getattr(data, via) elif hasattr(data, 'append'): @@ -491,15 +969,15 @@ class UniqueAppender(object): elif hasattr(data, 'add'): # TODO: we think its a set here. bypass unneeded uniquing logic ? self._data_appender = data.add - + def append(self, item): if item not in self._unique: self._data_appender(item) self._unique.add(item) - + def __iter__(self): return iter(self.data) - + class ScopedRegistry(object): """A Registry that can store one or multiple instances of a single class on a per-thread scoped basis, or on a customized scope. 
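
A usage sketch of UniqueAppender, which now deduplicates by identity via IdentitySet as shown
above; the target list here is arbitrary:

    from sqlalchemy import util

    target = []
    appender = util.UniqueAppender(target)
    item = object()
    appender.append(item)
    appender.append(item)             # a second append of the same object is ignored
    appender.append(object())         # a distinct object is appended
    assert len(target) == 2
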
@@ -528,6 +1006,9 @@ class ScopedRegistry(object): except KeyError: return self.registry.setdefault(key, self.createfunc()) + def has(self): + return self._get_key() in self.registry + def set(self, obj): self.registry[self._get_key()] = obj @@ -539,3 +1020,250 @@ class ScopedRegistry(object): def _get_key(self): return self.scopefunc() + +class _symbol(object): + def __init__(self, name): + """Construct a new named symbol.""" + assert isinstance(name, str) + self.name = name + def __reduce__(self): + return symbol, (self.name,) + def __repr__(self): + return "" % self.name +_symbol.__name__ = 'symbol' + +class symbol(object): + """A constant symbol. + + >>> symbol('foo') is symbol('foo') + True + >>> symbol('foo') + + + A slight refinement of the MAGICCOOKIE=object() pattern. The primary + advantage of symbol() is its repr(). They are also singletons. + + Repeated calls of symbol('name') will all return the same instance. + + """ + symbols = {} + _lock = threading.Lock() + + def __new__(cls, name): + cls._lock.acquire() + try: + sym = cls.symbols.get(name) + if sym is None: + cls.symbols[name] = sym = _symbol(name) + return sym + finally: + symbol._lock.release() + + +def as_interface(obj, cls=None, methods=None, required=None): + """Ensure basic interface compliance for an instance or dict of callables. + + Checks that ``obj`` implements public methods of ``cls`` or has members + listed in ``methods``. If ``required`` is not supplied, implementing at + least one interface method is sufficient. Methods present on ``obj`` that + are not in the interface are ignored. + + If ``obj`` is a dict and ``dict`` does not meet the interface + requirements, the keys of the dictionary are inspected. Keys present in + ``obj`` that are not in the interface will raise TypeErrors. + + Raises TypeError if ``obj`` does not meet the interface criteria. + + In all passing cases, an object with callable members is returned. In the + simple case, ``obj`` is returned as-is; if dict processing kicks in then + an anonymous class is returned. + + obj + A type, instance, or dictionary of callables. + cls + Optional, a type. All public methods of cls are considered the + interface. An ``obj`` instance of cls will always pass, ignoring + ``required``.. + methods + Optional, a sequence of method names to consider as the interface. + required + Optional, a sequence of mandatory implementations. If omitted, an + ``obj`` that provides at least one interface method is considered + sufficient. As a convenience, required may be a type, in which case + all public methods of the type are required. + + """ + if not cls and not methods: + raise TypeError('a class or collection of method names are required') + + if isinstance(cls, type) and isinstance(obj, cls): + return obj + + interface = Set(methods or [m for m in dir(cls) if not m.startswith('_')]) + implemented = Set(dir(obj)) + + complies = operator.ge + if isinstance(required, type): + required = interface + elif not required: + required = Set() + complies = operator.gt + else: + required = Set(required) + + if complies(implemented.intersection(interface), required): + return obj + + # No dict duck typing here. 
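
For illustration, the interning behavior of symbol() described above; the symbol name chosen here
is arbitrary:

    from sqlalchemy import util

    NO_VALUE = util.symbol('NO_VALUE')

    # repeated calls with the same name return the very same instance,
    # so identity checks remain valid across modules and after pickling
    assert util.symbol('NO_VALUE') is NO_VALUE
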
+ if not type(obj) is dict: + qualifier = complies is operator.gt and 'any of' or 'all of' + raise TypeError("%r does not implement %s: %s" % ( + obj, qualifier, ', '.join(interface))) + + class AnonymousInterface(object): + """A callable-holding shell.""" + + if cls: + AnonymousInterface.__name__ = 'Anonymous' + cls.__name__ + found = Set() + + for method, impl in dictlike_iteritems(obj): + if method not in interface: + raise TypeError("%r: unknown in this interface" % method) + if not callable(impl): + raise TypeError("%r=%r is not callable" % (method, impl)) + setattr(AnonymousInterface, method, staticmethod(impl)) + found.add(method) + + if complies(found, required): + return AnonymousInterface + + raise TypeError("dictionary does not contain required keys %s" % + ', '.join(required - found)) + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + """ + try: + fn.__name__ = name + except TypeError: + fn = new.function(fn.func_code, fn.func_globals, name, + fn.func_defaults, fn.func_closure) + return fn + +def conditional_cache_decorator(func): + """apply conditional caching to the return value of a function.""" + + return cache_decorator(func, conditional=True) + +def cache_decorator(func, conditional=False): + """apply caching to the return value of a function.""" + + name = '_cached_' + func.__name__ + + def do_with_cache(self, *args, **kwargs): + if conditional: + cache = kwargs.pop('cache', False) + if not cache: + return func(self, *args, **kwargs) + try: + return getattr(self, name) + except AttributeError: + value = func(self, *args, **kwargs) + setattr(self, name, value) + return value + return do_with_cache + +def reset_cached(instance, name): + try: + delattr(instance, '_cached_' + name) + except AttributeError: + pass + +def warn(msg): + if isinstance(msg, basestring): + warnings.warn(msg, exceptions.SAWarning, stacklevel=3) + else: + warnings.warn(msg, stacklevel=3) + +def warn_deprecated(msg): + warnings.warn(msg, exceptions.SADeprecationWarning, stacklevel=3) + +def deprecated(message=None, add_deprecation_to_docstring=True): + """Decorates a function and issues a deprecation warning on use. + + message + If provided, issue message in the warning. A sensible default + is used if not provided. + + add_deprecation_to_docstring + Default True. If False, the wrapped function's __doc__ is left + as-is. If True, the 'message' is prepended to the docs if + provided, or sensible default if message is omitted. + """ + + if add_deprecation_to_docstring: + header = message is not None and message or 'Deprecated.' + else: + header = None + + if message is None: + message = "Call to deprecated function %(func)s" + + def decorate(fn): + return _decorate_with_warning( + fn, exceptions.SADeprecationWarning, + message % dict(func=fn.__name__), header) + return decorate + +def pending_deprecation(version, message=None, + add_deprecation_to_docstring=True): + """Decorates a function and issues a pending deprecation warning on use. + + version + An approximate future version at which point the pending deprecation + will become deprecated. Not used in messaging. + + message + If provided, issue message in the warning. A sensible default + is used if not provided. + + add_deprecation_to_docstring + Default True. If False, the wrapped function's __doc__ is left + as-is. 
If True, the 'message' is prepended to the docs if + provided, or sensible default if message is omitted. + """ + + if add_deprecation_to_docstring: + header = message is not None and message or 'Deprecated.' + else: + header = None + + if message is None: + message = "Call to deprecated function %(func)s" + + def decorate(fn): + return _decorate_with_warning( + fn, exceptions.SAPendingDeprecationWarning, + message % dict(func=fn.__name__), header) + return decorate + +def _decorate_with_warning(func, wtype, message, docstring_header=None): + """Wrap a function with a warnings.warn and augmented docstring.""" + + def func_with_warning(*args, **kwargs): + warnings.warn(wtype(message), stacklevel=2) + return func(*args, **kwargs) + + doc = func.__doc__ is not None and func.__doc__ or '' + if docstring_header is not None: + doc = '\n'.join((docstring_header.rstrip(), doc)) + + func_with_warning.__doc__ = doc + func_with_warning.__dict__.update(func.__dict__) + + return function_named(func_with_warning, func.__name__) diff --git a/setup.py b/setup.py index 735f3d7234..bb8e689fde 100644 --- a/setup.py +++ b/setup.py @@ -1,21 +1,66 @@ from ez_setup import use_setuptools use_setuptools() from setuptools import setup, find_packages +from distutils.command.build_py import build_py as _build_py +from setuptools.command.sdist import sdist as _sdist +import os +from os import path + +v = open(path.join(path.dirname(__file__), 'VERSION')) +VERSION = v.readline().strip() +v.close() + +class build_py(_build_py): + def run(self): + init = path.join(self.build_lib, 'sqlalchemy', '__init__.py') + if path.exists(init): + os.unlink(init) + _build_py.run(self) + _stamp_version(init) + self.byte_compile([init]) + +class sdist(_sdist): + def make_release_tree (self, base_dir, files): + _sdist.make_release_tree(self, base_dir, files) + orig = path.join('lib', 'sqlalchemy', '__init__.py') + assert path.exists(orig) + dest = path.join(base_dir, orig) + if hasattr(os, 'link') and path.exists(dest): + os.unlink(dest) + self.copy_file(orig, dest) + _stamp_version(dest) + +def _stamp_version(filename): + found, out = False, [] + f = open(filename, 'r') + for line in f: + if '__version__ =' in line: + line = line.replace("'svn'", "'%s'" % VERSION) + found = True + out.append(line) + f.close() + + if found: + f = open(filename, 'w') + f.writelines(out) + f.close() + setup(name = "SQLAlchemy", - version = "0.4.0", - description = "Database Abstraction Library", - author = "Mike Bayer", - author_email = "mike_mp@zzzcomputing.com", - url = "http://www.sqlalchemy.org", - packages = find_packages('lib'), - package_dir = {'':'lib'}, - entry_points = { - 'sqlalchemy.databases': [ - '%s = sqlalchemy.databases.%s:dialect' % (f,f) for f in - ['sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird']]}, - license = "MIT License", - long_description = """\ + cmdclass={'build_py': build_py, 'sdist': sdist}, + version = VERSION, + description = "Database Abstraction Library", + author = "Mike Bayer", + author_email = "mike_mp@zzzcomputing.com", + url = "http://www.sqlalchemy.org", + packages = find_packages('lib'), + package_dir = {'':'lib'}, + entry_points = { + 'sqlalchemy.databases': [ + '%s = sqlalchemy.databases.%s:dialect' % (f,f) for f in + ['sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird']]}, + license = "MIT License", + long_description = """\ SQLAlchemy is: * The Python SQL toolkit and Object Relational Mapper that gives application developers the full power and flexibility of SQL. 
SQLAlchemy provides a full suite of well known enterprise-level persistence patterns, designed for efficient and high-performing database access, adapted into a simple and Pythonic domain language. @@ -45,15 +90,11 @@ SVN version: """, - classifiers = [ + classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Topic :: Database :: Front-Ends", - ] - ) - - - - + ] + ) diff --git a/test/alltests.py b/test/alltests.py index e3266f5632..b08d0af13d 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest import orm.alltests as orm @@ -8,13 +8,17 @@ import engine.alltests as engine import dialect.alltests as dialect import ext.alltests as ext import zblog.alltests as zblog +import profiling.alltests as profiling + +# The profiling tests are sensitive to foibles of CPython VM state, so +# run them first. Ideally, each should be run in a fresh interpreter. def suite(): alltests = unittest.TestSuite() - for suite in (base, engine, sql, dialect, orm, ext, zblog): + for suite in (profiling, base, engine, sql, dialect, orm, ext, zblog): alltests.addTest(suite.suite()) return alltests if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/base/alltests.py b/test/base/alltests.py index 44fa9b2ecf..803b8ea3c4 100644 --- a/test/base/alltests.py +++ b/test/base/alltests.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest def suite(): @@ -6,6 +6,7 @@ def suite(): # core utilities 'base.dependency', 'base.utils', + 'base.except', ) alltests = unittest.TestSuite() for name in modules_to_test: @@ -17,4 +18,4 @@ def suite(): if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/base/dependency.py b/test/base/dependency.py index ddadd1b316..25d34ffd39 100644 --- a/test/base/dependency.py +++ b/test/base/dependency.py @@ -1,32 +1,26 @@ -import testbase +import testenv; testenv.configure_for_tests() import sqlalchemy.topological as topological from sqlalchemy import util from testlib import * -# TODO: need assertion conditions in this suite - - -class DependencySorter(topological.QueueDependencySorter):pass - - -class DependencySortTest(PersistTest): +class DependencySortTest(TestBase): def assert_sort(self, tuples, node, collection=None): print str(node) def assert_tuple(tuple, node): - if node.cycles: - cycles = [i.item for i in node.cycles] + if node[1]: + cycles = node[1] else: cycles = [] - if tuple[0] is node.item or tuple[0] in cycles: + if tuple[0] is node[0] or tuple[0] in cycles: tuple.pop() - if tuple[0] is node.item or tuple[0] in cycles: + if tuple[0] is node[0] or tuple[0] in cycles: return - elif len(tuple) > 1 and tuple[1] is node.item: - assert False, "Tuple not in dependency tree: " + str(tuple) - for c in node.children: + elif len(tuple) > 1 and tuple[1] is node[0]: + assert False, "Tuple not in dependency tree: " + str(tuple) + " " + str(node) + for c in node[2]: assert_tuple(tuple, c) - + for tuple in tuples: assert_tuple(list(tuple), node) @@ -34,16 +28,16 @@ class DependencySortTest(PersistTest): collection = [] items = util.Set() def assert_unique(node): - for item in [n.item for n in node.cycles or [node,]]: + for item in [i for i in node[1] or [node[0]]]: assert item not in items items.add(item) if item in collection: collection.remove(item) - for c in node.children: + 
for c in node[2]: assert_unique(c) assert_unique(node) assert len(collection) == 0 - + def testsort(self): rootnode = 'root' node2 = 'node2' @@ -64,7 +58,7 @@ class DependencySortTest(PersistTest): (node4, subnode3), (node4, subnode4) ] - head = DependencySorter(tuples, []).sort() + head = topological.sort_as_tree(tuples, []) self.assert_sort(tuples, head) def testsort2(self): @@ -82,7 +76,7 @@ class DependencySortTest(PersistTest): (node5, node6), (node6, node2) ] - head = DependencySorter(tuples, [node7]).sort() + head = topological.sort_as_tree(tuples, [node7]) self.assert_sort(tuples, head, [node7]) def testsort3(self): @@ -95,10 +89,10 @@ class DependencySortTest(PersistTest): (node3, node2), (node1,node3) ] - head1 = DependencySorter(tuples, [node1, node2, node3]).sort() - head2 = DependencySorter(tuples, [node3, node1, node2]).sort() - head3 = DependencySorter(tuples, [node3, node2, node1]).sort() - + head1 = topological.sort_as_tree(tuples, [node1, node2, node3]) + head2 = topological.sort_as_tree(tuples, [node3, node1, node2]) + head3 = topological.sort_as_tree(tuples, [node3, node2, node1]) + # TODO: figure out a "node == node2" function #self.assert_(str(head1) == str(head2) == str(head3)) print "\n" + str(head1) @@ -116,11 +110,11 @@ class DependencySortTest(PersistTest): (node1, node3), (node3, node2) ] - head = DependencySorter(tuples, []).sort() + head = topological.sort_as_tree(tuples, []) self.assert_sort(tuples, head) def testsort5(self): - # this one, depenending on the weather, + # this one, depenending on the weather, node1 = 'node1' #'00B94190' node2 = 'node2' #'00B94990' node3 = 'node3' #'00B9A9B0' @@ -139,7 +133,7 @@ class DependencySortTest(PersistTest): node3, node4 ] - head = DependencySorter(tuples, allitems).sort() + head = topological.sort_as_tree(tuples, allitems, with_cycles=True) self.assert_sort(tuples, head) def testcircular(self): @@ -156,9 +150,10 @@ class DependencySortTest(PersistTest): (node3, node1), (node4, node1) ] - head = DependencySorter(tuples, []).sort(allow_all_cycles=True) + allitems = [node1, node2, node3, node4] + head = topological.sort_as_tree(tuples, allitems, with_cycles=True) self.assert_sort(tuples, head) - + def testcircular2(self): # this condition was arising from ticket:362 # and was not treated properly by topological sort @@ -173,22 +168,21 @@ class DependencySortTest(PersistTest): (node3, node2), (node2, node3) ] - head = DependencySorter(tuples, []).sort(allow_all_cycles=True) + head = topological.sort_as_tree(tuples, [], with_cycles=True) self.assert_sort(tuples, head) - + def testcircular3(self): nodes = {} tuples = [('Question', 'Issue'), ('ProviderService', 'Issue'), ('Provider', 'Question'), ('Question', 'Provider'), ('ProviderService', 'Question'), ('Provider', 'ProviderService'), ('Question', 'Answer'), ('Issue', 'Question')] - head = DependencySorter(tuples, []).sort(allow_all_cycles=True) + head = topological.sort_as_tree(tuples, [], with_cycles=True) self.assert_sort(tuples, head) - + def testbigsort(self): tuples = [] for i in range(0,1500, 2): tuples.append((i, i+1)) - head = DependencySorter(tuples, []).sort() - - - + head = topological.sort_as_tree(tuples, []) + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/base/except.py b/test/base/except.py new file mode 100644 index 0000000000..84b84793ce --- /dev/null +++ b/test/base/except.py @@ -0,0 +1,77 @@ +"""Tests exceptions and DB-API exception wrapping.""" +import testenv; testenv.configure_for_tests() +import sys, unittest 
+import exceptions as stdlib_exceptions +from sqlalchemy import exceptions as sa_exceptions +from testlib import * + + +class Error(stdlib_exceptions.StandardError): + """This class will be old-style on <= 2.4 and new-style on >= 2.5.""" +class DatabaseError(Error): + pass +class OperationalError(DatabaseError): + pass +class ProgrammingError(DatabaseError): + def __str__(self): + return "<%s>" % self.bogus +class OutOfSpec(DatabaseError): + pass + + +class WrapTest(unittest.TestCase): + def test_db_error_normal(self): + try: + raise sa_exceptions.DBAPIError.instance( + '', [], OperationalError()) + except sa_exceptions.DBAPIError: + self.assert_(True) + + def test_db_error_busted_dbapi(self): + try: + raise sa_exceptions.DBAPIError.instance( + '', [], ProgrammingError()) + except sa_exceptions.DBAPIError, e: + self.assert_(True) + self.assert_('Error in str() of DB-API' in e.args[0]) + + def test_db_error_noncompliant_dbapi(self): + try: + raise sa_exceptions.DBAPIError.instance( + '', [], OutOfSpec()) + except sa_exceptions.DBAPIError, e: + self.assert_(e.__class__ is sa_exceptions.DBAPIError) + except OutOfSpec: + self.assert_(False) + + # Make sure the DatabaseError recognition logic is limited to + # subclasses of sqlalchemy.exceptions.DBAPIError + try: + raise sa_exceptions.DBAPIError.instance( + '', [], sa_exceptions.AssertionError()) + except sa_exceptions.DBAPIError, e: + self.assert_(e.__class__ is sa_exceptions.DBAPIError) + except sa_exceptions.AssertionError: + self.assert_(False) + + def test_db_error_keyboard_interrupt(self): + try: + raise sa_exceptions.DBAPIError.instance( + '', [], stdlib_exceptions.KeyboardInterrupt()) + except sa_exceptions.DBAPIError: + self.assert_(False) + except stdlib_exceptions.KeyboardInterrupt: + self.assert_(True) + + def test_db_error_system_exit(self): + try: + raise sa_exceptions.DBAPIError.instance( + '', [], stdlib_exceptions.SystemExit()) + except sa_exceptions.DBAPIError: + self.assert_(False) + except stdlib_exceptions.SystemExit: + self.assert_(True) + + +if __name__ == "__main__": + testenv.main() diff --git a/test/base/utils.py b/test/base/utils.py index 97f3db06fc..a00338f5f5 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -1,9 +1,10 @@ -import testbase -from sqlalchemy import util, column, sql, exceptions +import testenv; testenv.configure_for_tests() +import unittest +from sqlalchemy import util, sql, exceptions from testlib import * +from testlib import sorted - -class OrderedDictTest(PersistTest): +class OrderedDictTest(TestBase): def test_odict(self): o = util.OrderedDict() o['a'] = 1 @@ -13,7 +14,7 @@ class OrderedDictTest(PersistTest): self.assert_(o.keys() == ['a', 'b', 'snack', 'c']) self.assert_(o.values() == [1, 2, 'attack', 3]) - + o.pop('snack') self.assert_(o.keys() == ['a', 'b', 'c']) @@ -34,12 +35,24 @@ class OrderedDictTest(PersistTest): self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e', 'f']) self.assert_(o.values() == [1, 2, 3, 4, 5, 6]) -class ColumnCollectionTest(PersistTest): +class OrderedSetTest(TestBase): + def test_mutators_against_iter(self): + # testing a set modified against an iterator + o = util.OrderedSet([3,2, 4, 5]) + + self.assertEquals(o.difference(iter([3,4])), + util.OrderedSet([2,5])) + self.assertEquals(o.intersection(iter([3,4, 6])), + util.OrderedSet([3, 4])) + self.assertEquals(o.union(iter([3,4, 6])), + util.OrderedSet([2, 3, 4, 5, 6])) + +class ColumnCollectionTest(TestBase): def test_in(self): cc = sql.ColumnCollection() - cc.add(column('col1')) - cc.add(column('col2')) - 
cc.add(column('col3')) + cc.add(sql.column('col1')) + cc.add(sql.column('col2')) + cc.add(sql.column('col3')) assert 'col1' in cc assert 'col2' in cc @@ -48,20 +61,458 @@ class ColumnCollectionTest(PersistTest): assert False except exceptions.ArgumentError, e: assert str(e) == "__contains__ requires a string argument" - + def test_compare(self): cc1 = sql.ColumnCollection() cc2 = sql.ColumnCollection() cc3 = sql.ColumnCollection() - c1 = column('col1') + c1 = sql.column('col1') c2 = c1.label('col2') - c3 = column('col3') + c3 = sql.column('col3') cc1.add(c1) cc2.add(c2) cc3.add(c3) assert (cc1==cc2).compare(c1 == c2) assert not (cc1==cc3).compare(c2 == c3) - - + +class ArgSingletonTest(unittest.TestCase): + def test_cleanout(self): + util.ArgSingleton.instances.clear() + + class MyClass(object): + __metaclass__ = util.ArgSingleton + def __init__(self, x, y): + self.x = x + self.y = y + + m1 = MyClass(3, 4) + m2 = MyClass(1, 5) + m3 = MyClass(3, 4) + assert m1 is m3 + assert m2 is not m3 + assert len(util.ArgSingleton.instances) == 2 + + m1 = m2 = m3 = None + MyClass.dispose(MyClass) + assert len(util.ArgSingleton.instances) == 0 + + +class ImmutableSubclass(str): + pass + +class HashOverride(object): + def __init__(self, value=None): + self.value = value + def __hash__(self): + return hash(self.value) + +class EqOverride(object): + def __init__(self, value=None): + self.value = value + def __eq__(self, other): + if isinstance(other, EqOverride): + return self.value == other.value + else: + return False + def __ne__(self, other): + if isinstance(other, EqOverride): + return self.value != other.value + else: + return True +class HashEqOverride(object): + def __init__(self, value=None): + self.value = value + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, EqOverride): + return self.value == other.value + else: + return False + def __ne__(self, other): + if isinstance(other, EqOverride): + return self.value != other.value + else: + return True + + +class IdentitySetTest(unittest.TestCase): + def assert_eq(self, identityset, expected_iterable): + expected = sorted([id(o) for o in expected_iterable]) + found = sorted([id(o) for o in identityset]) + self.assertEquals(found, expected) + + def test_init(self): + ids = util.IdentitySet([1,2,3,2,1]) + self.assert_eq(ids, [1,2,3]) + + ids = util.IdentitySet(ids) + self.assert_eq(ids, [1,2,3]) + + ids = util.IdentitySet() + self.assert_eq(ids, []) + + ids = util.IdentitySet([]) + self.assert_eq(ids, []) + + ids = util.IdentitySet(ids) + self.assert_eq(ids, []) + + def test_add(self): + for type_ in (object, ImmutableSubclass): + data = [type_(), type_()] + ids = util.IdentitySet() + for i in range(2) + range(2): + ids.add(data[i]) + self.assert_eq(ids, data) + + for type_ in (EqOverride, HashOverride, HashEqOverride): + data = [type_(1), type_(1), type_(2)] + ids = util.IdentitySet() + for i in range(3) + range(3): + ids.add(data[i]) + self.assert_eq(ids, data) + + def test_basic_sanity(self): + IdentitySet = util.IdentitySet + + o1, o2, o3 = object(), object(), object() + ids = IdentitySet([o1]) + ids.discard(o1) + ids.discard(o1) + ids.add(o1) + ids.remove(o1) + self.assertRaises(KeyError, ids.remove, o1) + + self.assert_(ids.copy() == ids) + self.assert_(ids != None) + self.assert_(not(ids == None)) + self.assert_(ids != IdentitySet([o1,o2,o3])) + ids.clear() + self.assert_(o1 not in ids) + ids.add(o2) + self.assert_(o2 in ids) + self.assert_(ids.pop() == o2) + ids.add(o1) + self.assert_(len(ids) == 
1) + + isuper = IdentitySet([o1,o2]) + self.assert_(ids < isuper) + self.assert_(ids.issubset(isuper)) + self.assert_(isuper.issuperset(ids)) + self.assert_(isuper > ids) + + self.assert_(ids.union(isuper) == isuper) + self.assert_(ids | isuper == isuper) + self.assert_(isuper - ids == IdentitySet([o2])) + self.assert_(isuper.difference(ids) == IdentitySet([o2])) + self.assert_(ids.intersection(isuper) == IdentitySet([o1])) + self.assert_(ids & isuper == IdentitySet([o1])) + self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2])) + self.assert_(ids ^ isuper == IdentitySet([o2])) + + ids.update(isuper) + ids |= isuper + ids.difference_update(isuper) + ids -= isuper + ids.intersection_update(isuper) + ids &= isuper + ids.symmetric_difference_update(isuper) + ids ^= isuper + + ids.update('foobar') + try: + ids |= 'foobar' + self.assert_(False) + except TypeError: + self.assert_(True) + + try: + s = set([o1,o2]) + s |= ids + self.assert_(False) + except TypeError: + self.assert_(True) + + self.assertRaises(TypeError, cmp, ids) + self.assertRaises(TypeError, hash, ids) + + def test_difference(self): + os1 = util.IdentitySet([1,2,3]) + os2 = util.IdentitySet([3,4,5]) + s1 = set([1,2,3]) + s2 = set([3,4,5]) + + self.assertEquals(os1 - os2, util.IdentitySet([1, 2])) + self.assertEquals(os2 - os1, util.IdentitySet([4, 5])) + self.assertRaises(TypeError, lambda: os1 - s2) + self.assertRaises(TypeError, lambda: os1 - [3, 4, 5]) + self.assertRaises(TypeError, lambda: s1 - os2) + self.assertRaises(TypeError, lambda: s1 - [3, 4, 5]) + + +class DictlikeIteritemsTest(unittest.TestCase): + baseline = set([('a', 1), ('b', 2), ('c', 3)]) + + def _ok(self, instance): + iterator = util.dictlike_iteritems(instance) + self.assertEquals(set(iterator), self.baseline) + + def _notok(self, instance): + self.assertRaises(TypeError, + util.dictlike_iteritems, + instance) + + def test_dict(self): + d = dict(a=1,b=2,c=3) + self._ok(d) + + def test_subdict(self): + class subdict(dict): + pass + d = subdict(a=1,b=2,c=3) + self._ok(d) + + def test_UserDict(self): + import UserDict + d = UserDict.UserDict(a=1,b=2,c=3) + self._ok(d) + + def test_object(self): + self._notok(object()) + + def test_duck_1(self): + class duck1(object): + def iteritems(duck): + return iter(self.baseline) + self._ok(duck1()) + + def test_duck_2(self): + class duck2(object): + def items(duck): + return list(self.baseline) + self._ok(duck2()) + + def test_duck_3(self): + class duck3(object): + def iterkeys(duck): + return iter(['a', 'b', 'c']) + def __getitem__(duck, key): + return dict(a=1,b=2,c=3).get(key) + self._ok(duck3()) + + def test_duck_4(self): + class duck4(object): + def iterkeys(duck): + return iter(['a', 'b', 'c']) + self._notok(duck4()) + + def test_duck_5(self): + class duck5(object): + def keys(duck): + return ['a', 'b', 'c'] + def get(duck, key): + return dict(a=1,b=2,c=3).get(key) + self._ok(duck5()) + + def test_duck_6(self): + class duck6(object): + def keys(duck): + return ['a', 'b', 'c'] + self._notok(duck6()) + + +class ArgInspectionTest(TestBase): + def test_get_cls_kwargs(self): + class A(object): + def __init__(self, a): + pass + class A1(A): + def __init__(self, a1): + pass + class A11(A1): + def __init__(self, a11, **kw): + pass + class B(object): + def __init__(self, b, **kw): + pass + class B1(B): + def __init__(self, b1, **kw): + pass + class AB(A, B): + def __init__(self, ab): + pass + class BA(B, A): + def __init__(self, ba, **kwargs): + pass + class BA1(BA): + pass + class CAB(A, B): + pass + class 
CBA(B, A): + pass + class CAB1(A, B1): + pass + class CB1A(B1, A): + pass + class D(object): + pass + + def test(cls, *expected): + self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected)) + + test(A, 'a') + test(A1, 'a1') + test(A11, 'a11', 'a1') + test(B, 'b') + test(B1, 'b1', 'b') + test(AB, 'ab') + test(BA, 'ba', 'b', 'a') + test(BA1, 'ba', 'b', 'a') + test(CAB, 'a') + test(CBA, 'b') + test(CAB1, 'a') + test(CB1A, 'b1', 'b') + test(D) + + def test_get_func_kwargs(self): + def f1(): pass + def f2(foo): pass + def f3(*foo): pass + def f4(**foo): pass + + def test(fn, *expected): + self.assertEquals(set(util.get_func_kwargs(fn)), set(expected)) + + test(f1) + test(f2, 'foo') + test(f3) + test(f4) + +class SymbolTest(TestBase): + def test_basic(self): + sym1 = util.symbol('foo') + assert sym1.name == 'foo' + sym2 = util.symbol('foo') + + assert sym1 is sym2 + assert sym1 == sym2 + + sym3 = util.symbol('bar') + assert sym1 is not sym3 + assert sym1 != sym3 + + def test_pickle(self): + sym1 = util.symbol('foo') + sym2 = util.symbol('foo') + + assert sym1 is sym2 + + # default + s = util.pickle.dumps(sym1) + sym3 = util.pickle.loads(s) + + for protocol in 0, 1, 2: + print protocol + serial = util.pickle.dumps(sym1) + rt = util.pickle.loads(serial) + assert rt is sym1 + assert rt is sym2 + +class AsInterfaceTest(TestBase): + class Something(object): + def _ignoreme(self): pass + def foo(self): pass + def bar(self): pass + + class Partial(object): + def bar(self): pass + + class Object(object): pass + + def test_instance(self): + obj = object() + self.assertRaises(TypeError, util.as_interface, obj, + cls=self.Something) + + self.assertRaises(TypeError, util.as_interface, obj, + methods=('foo')) + + self.assertRaises(TypeError, util.as_interface, obj, + cls=self.Something, required=('foo')) + + obj = self.Something() + self.assertEqual(obj, util.as_interface(obj, cls=self.Something)) + self.assertEqual(obj, util.as_interface(obj, methods=('foo',))) + self.assertEqual( + obj, util.as_interface(obj, cls=self.Something, + required=('outofband',))) + partial = self.Partial() + + slotted = self.Object() + slotted.bar = lambda self: 123 + + for obj in partial, slotted: + self.assertEqual(obj, util.as_interface(obj, cls=self.Something)) + self.assertRaises(TypeError, util.as_interface, obj, + methods=('foo')) + self.assertEqual(obj, util.as_interface(obj, methods=('bar',))) + self.assertEqual( + obj, util.as_interface(obj, cls=self.Something, + required=('bar',))) + self.assertRaises(TypeError, util.as_interface, obj, + cls=self.Something, required=('foo',)) + + self.assertRaises(TypeError, util.as_interface, obj, + cls=self.Something, required=self.Something) + + def test_dict(self): + obj = {} + + self.assertRaises(TypeError, util.as_interface, obj, + cls=self.Something) + self.assertRaises(TypeError, util.as_interface, obj, + methods=('foo')) + self.assertRaises(TypeError, util.as_interface, obj, + cls=self.Something, required=('foo')) + + def assertAdapted(obj, *methods): + assert isinstance(obj, type) + found = set([m for m in dir(obj) if not m.startswith('_')]) + for method in methods: + assert method in found + found.remove(method) + assert not found + + fn = lambda self: 123 + + obj = {'foo': fn, 'bar': fn} + + res = util.as_interface(obj, cls=self.Something) + assertAdapted(res, 'foo', 'bar') + + res = util.as_interface(obj, cls=self.Something, required=self.Something) + assertAdapted(res, 'foo', 'bar') + + res = util.as_interface(obj, cls=self.Something, required=('foo',)) + 
assertAdapted(res, 'foo', 'bar') + + res = util.as_interface(obj, methods=('foo', 'bar')) + assertAdapted(res, 'foo', 'bar') + + res = util.as_interface(obj, methods=('foo', 'bar', 'baz')) + assertAdapted(res, 'foo', 'bar') + + res = util.as_interface(obj, methods=('foo', 'bar'), required=('foo',)) + assertAdapted(res, 'foo', 'bar') + + self.assertRaises(TypeError, util.as_interface, obj, methods=('foo',)) + + self.assertRaises(TypeError, util.as_interface, obj, + methods=('foo', 'bar', 'baz'), required=('baz',)) + + obj = {'foo': 123} + self.assertRaises(TypeError, util.as_interface, obj, cls=self.Something) + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/clone.py b/test/clone.py new file mode 100644 index 0000000000..f56ab8cf18 --- /dev/null +++ b/test/clone.py @@ -0,0 +1,175 @@ +# only tested with cpython! +import optparse, os, shutil, sys +from os import path +from testlib import filters + +__doc__ = """ +Creates and maintains a 'clone' of the test suite, optionally transforming +the source code through a filter. The primary purpose of this utility is +to allow the tests to run on Python VMs that do not implement a parser that +groks 2.4 style @decorations. + +Creating a clone: + + Create a new, exact clone of the suite: + $ python test/clone.py -c myclone + + Create a new clone using the 2.3 filter: + $ python test/clone.py -c --filter=py23 myclone + +After the clone is set up, changes in the master can be pulled into the clone +with the -u or --update switch. If the clone was created with a filter, it +will be applied automatically when updating. + + Update the clone: + $ python test/clone.py -u myclone + +The updating algorithm is very simple: if the version in test/ is newer than +the one in your clone, the clone version is overwritten. 
+""" + +options = None +clone, clone_path = None, None +filter = lambda x: x[:] + +def optparser(): + parser = optparse.OptionParser( + usage=('usage: %prog [options] CLONE-NAME\n' + __doc__ ).rstrip()) + parser.add_option('-n', '--dry-run', dest='dryrun', + action='store_true', + help=('Do not actually change any files; ' + 'just print what would happen.')) + parser.add_option('-u', '--update', dest='update', action='store_true', + help='Update an existing clone.') + parser.add_option('-c', '--create', dest='create', action='store_true', + help='Create a new clone.') + parser.add_option('--filter', dest='filter', + help='Run source code through a filter.') + parser.add_option('-l', '--filter-list', dest='filter_list', + action='store_true', + help='Show available filters.') + parser.add_option('-f', '--force', dest='force', action='store_true', + help='Overwrite clone files even if unchanged.') + parser.add_option('-q', '--quiet', dest='quiet', action='store_true', + help='Run quietly.') + parser.set_defaults(update=False, create=False, + dryrun=False, filter_list=False, + force=False, quiet=False) + return parser + +def config(): + global clone, clone_path, options, filter + + parser = optparser() + (options, args) = parser.parse_args() + + if options.filter_list: + if options.quiet: + print '\n'.join(filters.__all__) + else: + print 'Available filters:' + for name in filters.__all__: + print '\t%s' % name + sys.exit(0) + + if not options.update and not options.create: + parser.error('One of -u or -c is required.') + + if len(args) != 1: + parser.error('A clone name is required.') + + clone = args[0] + clone_path = path.abspath(clone) + + if options.update and not path.exists(clone_path): + parser.error( + 'Clone %s does not exist; create it with --create first.' % clone) + if options.create and path.exists(clone_path): + parser.error('Clone %s already exists.' % clone) + + if options.filter: + if options.filter not in filters.__all__: + parser.error(('Filter "%s" unknown; use --filter-list to see ' + 'available filters.') % options.filter) + filter = getattr(filters, options.filter) + +def setup(): + global filter + + if options.create: + if not options.quiet: + print "mkdir %s" % clone_path + if not options.dryrun: + os.mkdir(clone_path) + + if options.filter and not options.dryrun: + if not options.quiet: + print 'storing filter "%s" in %s/.filter' % ( + options.filter, clone) + stash = open(path.join(clone_path, '.filter'), 'w') + stash.write(options.filter) + stash.close() + else: + stash_file = path.join(clone_path, '.filter') + if path.exists(stash_file): + stash = open(stash_file) + stashed = stash.read().strip() + stash.close() + if options.filter: + if (options.filter != stashed and stashed in filters.__all__ and + not options.quiet): + print (('Warning: --filter=%s overrides %s specified in ' + '%s/.filter') % (options.filter, stashed, clone)) + else: + if stashed not in filters.__all__: + sys.stderr.write( + 'Filter "%s" in %s/.filter is not valid, aborting.' 
% + (stashed, clone)) + sys.exit(-1) + filter = getattr(filters, stashed) + +def sync(): + source_path, _ = path.split(path.abspath(__file__)) + + ls = lambda root: [fn + for fn in os.listdir(root) + if (fn.endswith('.py') and not fn.startswith('.'))] + + def walker(x, dirname, fnames): + if '.svn' in fnames: + fnames.remove('.svn') + + rel_path = dirname[len(source_path) + 1:] + dest_path = path.join(clone_path, rel_path) + + if not path.exists(dest_path): + if not options.quiet: + print "mkdir %s/%s" % (clone, rel_path) + if not options.dryrun: + os.mkdir(dest_path) + + for filename in ls(dirname): + source_file = path.join(source_path, rel_path, filename) + dest_file = path.join(dest_path, filename) + + if (options.force or + (not path.exists(dest_file) or + os.stat(source_file)[-1] > os.stat(dest_file)[-1])): + if not options.quiet: + print "syncing %s" % path.join(rel_path, filename) + + raw = open(source_file) + filtered = filter(raw.readlines()) + raw.close() + + if not options.dryrun: + synced = open(dest_file, 'w') + synced.writelines(filtered) + synced.close() + + os.path.walk(source_path, walker, None) + +if __name__ == '__main__': + config() + setup() + sync() diff --git a/test/dialect/access.py b/test/dialect/access.py new file mode 100644 index 0000000000..311231947e --- /dev/null +++ b/test/dialect/access.py @@ -0,0 +1,15 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy.databases import access +from testlib import * + + +class BasicTest(TestBase, AssertsExecutionResults): + # A simple import of the database/ module should work on all systems. + def test_import(self): + # we got this far, right? + return True + + +if __name__ == "__main__": + testenv.main() diff --git a/test/dialect/alltests.py b/test/dialect/alltests.py index 8900736259..d40f0d6d46 100644 --- a/test/dialect/alltests.py +++ b/test/dialect/alltests.py @@ -1,11 +1,19 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest + def suite(): modules_to_test = ( + 'dialect.access', + 'dialect.firebird', + 'dialect.informix', + 'dialect.maxdb', + 'dialect.mssql', 'dialect.mysql', - 'dialect.postgres', 'dialect.oracle', + 'dialect.postgres', + 'dialect.sqlite', + 'dialect.sybase', ) alltests = unittest.TestSuite() for name in modules_to_test: @@ -16,6 +24,5 @@ def suite(): return alltests - if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/dialect/firebird.py b/test/dialect/firebird.py new file mode 100644 index 0000000000..f929443fd4 --- /dev/null +++ b/test/dialect/firebird.py @@ -0,0 +1,121 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy.databases import firebird +from sqlalchemy.exceptions import ProgrammingError +from sqlalchemy.sql import table, column +from testlib import * + + +class DomainReflectionTest(TestBase, AssertsExecutionResults): + "Test Firebird domains" + + __only_on__ = 'firebird' + + def setUpAll(self): + con = testing.db.connect() + try: + con.execute('CREATE DOMAIN int_domain AS INTEGER DEFAULT 42 NOT NULL') + con.execute('CREATE DOMAIN str_domain AS VARCHAR(255)') + con.execute('CREATE DOMAIN rem_domain AS BLOB SUB_TYPE TEXT') + con.execute('CREATE DOMAIN img_domain AS BLOB SUB_TYPE BINARY') + except ProgrammingError, e: + if not "attempt to store duplicate value" in str(e): + raise e + con.execute('''CREATE GENERATOR gen_testtable_id''') + con.execute('''CREATE TABLE testtable (question int_domain, + answer str_domain DEFAULT 'no 
answer', + remark rem_domain DEFAULT '', + photo img_domain, + d date, + t time, + dt timestamp)''') + con.execute('''ALTER TABLE testtable + ADD CONSTRAINT testtable_pk PRIMARY KEY (question)''') + con.execute('''CREATE TRIGGER testtable_autoid FOR testtable + ACTIVE BEFORE INSERT AS + BEGIN + IF (NEW.question IS NULL) THEN + NEW.question = gen_id(gen_testtable_id, 1); + END''') + + def tearDownAll(self): + con = testing.db.connect() + con.execute('DROP TABLE testtable') + con.execute('DROP DOMAIN int_domain') + con.execute('DROP DOMAIN str_domain') + con.execute('DROP DOMAIN rem_domain') + con.execute('DROP DOMAIN img_domain') + con.execute('DROP GENERATOR gen_testtable_id') + + def test_table_is_reflected(self): + metadata = MetaData(testing.db) + table = Table('testtable', metadata, autoload=True) + self.assertEquals(set(table.columns.keys()), + set(['question', 'answer', 'remark', 'photo', 'd', 't', 'dt']), + "Columns of reflected table didn't equal expected columns") + self.assertEquals(table.c.question.primary_key, True) + self.assertEquals(table.c.question.sequence.name, 'gen_testtable_id') + self.assertEquals(table.c.question.type.__class__, firebird.FBInteger) + self.assertEquals(table.c.question.default.arg.text, "42") + self.assertEquals(table.c.answer.type.__class__, firebird.FBString) + self.assertEquals(table.c.answer.default.arg.text, "'no answer'") + self.assertEquals(table.c.remark.type.__class__, firebird.FBText) + self.assertEquals(table.c.remark.default.arg.text, "''") + self.assertEquals(table.c.photo.type.__class__, firebird.FBBinary) + # The following assume a Dialect 3 database + self.assertEquals(table.c.d.type.__class__, firebird.FBDate) + self.assertEquals(table.c.t.type.__class__, firebird.FBTime) + self.assertEquals(table.c.dt.type.__class__, firebird.FBDateTime) + + +class CompileTest(TestBase, AssertsCompiledSQL): + __dialect__ = firebird.FBDialect() + + def test_alias(self): + t = table('sometable', column('col1'), column('col2')) + s = select([t.alias()]) + self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1") + + def test_function(self): + self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)") + self.assert_compile(func.current_time(), "CURRENT_TIME") + self.assert_compile(func.foo(), "foo") + + m = MetaData() + t = Table('sometable', m, Column('col1', Integer), Column('col2', Integer)) + self.assert_compile(select([func.max(t.c.col1)]), "SELECT max(sometable.col1) AS max_1 FROM sometable") + + def test_substring(self): + self.assert_compile(func.substring('abc', 1, 2), "SUBSTRING(:substring_1 FROM :substring_2 FOR :substring_3)") + self.assert_compile(func.substring('abc', 1), "SUBSTRING(:substring_1 FROM :substring_2)") + +class MiscFBTests(TestBase): + + __only_on__ = 'firebird' + + def test_strlen(self): + # On FB the length() function is implemented by an external + # UDF, strlen(). Various SA tests fail because they pass a + # parameter to it, and that does not work (it always results + # the maximum string length the UDF was declared to accept). + # This test checks that at least it works ok in other cases. 
+ + meta = MetaData(testing.db) + t = Table('t1', meta, + Column('id', Integer, Sequence('t1idseq'), primary_key=True), + Column('name', String(10)) + ) + meta.create_all() + try: + t.insert(values=dict(name='dante')).execute() + t.insert(values=dict(name='alighieri')).execute() + select([func.count(t.c.id)],func.length(t.c.name)==5).execute().fetchone()[0] == 1 + finally: + meta.drop_all() + + def test_server_version_info(self): + version = testing.db.dialect.server_version_info(testing.db.connect()) + assert len(version) == 3, "Got strange version info: %s" % repr(version) + +if __name__ == '__main__': + testenv.main() diff --git a/test/dialect/informix.py b/test/dialect/informix.py new file mode 100644 index 0000000000..1fbbaa0cb4 --- /dev/null +++ b/test/dialect/informix.py @@ -0,0 +1,23 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy.databases import informix +from testlib import * + + +class CompileTest(TestBase, AssertsCompiledSQL): + __dialect__ = informix.InfoDialect() + + def test_statements(self): + meta =MetaData() + t1= Table('t1', meta, Column('col1', Integer, primary_key=True), Column('col2', String(50))) + t2= Table('t2', meta, Column('col1', Integer, primary_key=True), Column('col2', String(50)), Column('col3', Integer, ForeignKey('t1.col1'))) + + self.assert_compile(t1.select(), "SELECT t1.col1, t1.col2 FROM t1") + + self.assert_compile(select([t1, t2]).select_from(t1.join(t2)), "SELECT t1.col1, t1.col2, t2.col1, t2.col2, t2.col3 FROM t1 JOIN t2 ON t1.col1 = t2.col3") + + self.assert_compile(t1.update().values({t1.c.col1 : t1.c.col1 + 1}), 'UPDATE t1 SET col1=(t1.col1 + ?)') + + +if __name__ == "__main__": + testenv.main() diff --git a/test/dialect/maxdb.py b/test/dialect/maxdb.py new file mode 100644 index 0000000000..0a35f54705 --- /dev/null +++ b/test/dialect/maxdb.py @@ -0,0 +1,240 @@ +"""MaxDB-specific tests.""" + +import testenv; testenv.configure_for_tests() +import StringIO, sys +from sqlalchemy import * +from sqlalchemy import exceptions, sql +from sqlalchemy.util import Decimal +from sqlalchemy.databases import maxdb +from testlib import * + + +# TODO +# - add "Database" test, a quick check for join behavior on different max versions +# - full max-specific reflection suite +# - datetime tests +# - the orm/query 'test_has' destabilizes the server- cover here + +class ReflectionTest(TestBase, AssertsExecutionResults): + """Extra reflection tests.""" + + __only_on__ = 'maxdb' + + def _test_decimal(self, tabledef): + """Checks a variety of FIXED usages. + + This is primarily for SERIAL columns, which can be FIXED (scale-less) + or (SMALL)INT. Ensures that FIXED id columns are converted to + integers and that are assignable as such. Also exercises general + decimal assignment and selection behavior. 
+ """ + + meta = MetaData(testing.db) + try: + if isinstance(tabledef, basestring): + # run textual CREATE TABLE + testing.db.execute(tabledef) + else: + _t = tabledef.tometadata(meta) + _t.create() + t = Table('dectest', meta, autoload=True) + + vals = [Decimal('2.2'), Decimal('23'), Decimal('2.4'), 25] + cols = ['d1','d2','n1','i1'] + t.insert().execute(dict(zip(cols,vals))) + roundtrip = list(t.select().execute()) + self.assertEquals(roundtrip, [tuple([1] + vals)]) + + t.insert().execute(dict(zip(['id'] + cols, + [2] + list(roundtrip[0][1:])))) + roundtrip2 = list(t.select(order_by=t.c.id).execute()) + self.assertEquals(roundtrip2, [tuple([1] + vals), + tuple([2] + vals)]) + finally: + try: + testing.db.execute("DROP TABLE dectest") + except exceptions.DatabaseError: + pass + + def test_decimal_fixed_serial(self): + tabledef = """ + CREATE TABLE dectest ( + id FIXED(10) DEFAULT SERIAL PRIMARY KEY, + d1 FIXED(10,2), + d2 FIXED(12), + n1 NUMERIC(12,2), + i1 INTEGER) + """ + return self._test_decimal(tabledef) + + def test_decimal_integer_serial(self): + tabledef = """ + CREATE TABLE dectest ( + id INTEGER DEFAULT SERIAL PRIMARY KEY, + d1 DECIMAL(10,2), + d2 DECIMAL(12), + n1 NUMERIC(12,2), + i1 INTEGER) + """ + return self._test_decimal(tabledef) + + def test_decimal_implicit_serial(self): + tabledef = """ + CREATE TABLE dectest ( + id SERIAL PRIMARY KEY, + d1 FIXED(10,2), + d2 FIXED(12), + n1 NUMERIC(12,2), + i1 INTEGER) + """ + return self._test_decimal(tabledef) + + def test_decimal_smallint_serial(self): + tabledef = """ + CREATE TABLE dectest ( + id SMALLINT DEFAULT SERIAL PRIMARY KEY, + d1 FIXED(10,2), + d2 FIXED(12), + n1 NUMERIC(12,2), + i1 INTEGER) + """ + return self._test_decimal(tabledef) + + def test_decimal_sa_types_1(self): + tabledef = Table('dectest', MetaData(), + Column('id', Integer, primary_key=True), + Column('d1', DECIMAL(10, 2)), + Column('d2', DECIMAL(12)), + Column('n1', NUMERIC(12,2)), + Column('i1', Integer)) + return self._test_decimal(tabledef) + + def test_decimal_sa_types_2(self): + tabledef = Table('dectest', MetaData(), + Column('id', Integer, primary_key=True), + Column('d1', maxdb.MaxNumeric(10, 2)), + Column('d2', maxdb.MaxNumeric(12)), + Column('n1', maxdb.MaxNumeric(12,2)), + Column('i1', Integer)) + return self._test_decimal(tabledef) + + def test_decimal_sa_types_3(self): + tabledef = Table('dectest', MetaData(), + Column('id', Integer, primary_key=True), + Column('d1', maxdb.MaxNumeric(10, 2)), + Column('d2', maxdb.MaxNumeric), + Column('n1', maxdb.MaxNumeric(12,2)), + Column('i1', Integer)) + return self._test_decimal(tabledef) + + def test_assorted_type_aliases(self): + """Ensures that aliased types are reflected properly.""" + + meta = MetaData(testing.db) + try: + testing.db.execute(""" + CREATE TABLE assorted ( + c1 INT, + c2 BINARY(2), + c3 DEC(4,2), + c4 DEC(4), + c5 DEC, + c6 DOUBLE PRECISION, + c7 NUMERIC(4,2), + c8 NUMERIC(4), + c9 NUMERIC, + c10 REAL(4), + c11 REAL, + c12 CHARACTER(2)) + """) + table = Table('assorted', meta, autoload=True) + expected = [maxdb.MaxInteger, + maxdb.MaxNumeric, + maxdb.MaxNumeric, + maxdb.MaxNumeric, + maxdb.MaxNumeric, + maxdb.MaxFloat, + maxdb.MaxNumeric, + maxdb.MaxNumeric, + maxdb.MaxNumeric, + maxdb.MaxFloat, + maxdb.MaxFloat, + maxdb.MaxChar,] + for i, col in enumerate(table.columns): + self.assert_(isinstance(col.type, expected[i])) + finally: + try: + testing.db.execute("DROP TABLE assorted") + except exceptions.DatabaseError: + pass + +class DBAPITest(TestBase, AssertsExecutionResults): + 
"""Asserts quirks in the native Python DB-API driver. + + If any of these fail, that's good- the bug is fixed! + """ + + __only_on__ = 'maxdb' + + def test_dbapi_breaks_sequences(self): + con = testing.db.connect().connection + + cr = con.cursor() + cr.execute('CREATE SEQUENCE busto START WITH 1 INCREMENT BY 1') + try: + vals = [] + for i in xrange(3): + cr.execute('SELECT busto.NEXTVAL FROM DUAL') + vals.append(cr.fetchone()[0]) + + # should be 1,2,3, but no... + self.assert_(vals != [1,2,3]) + # ...we get: + self.assert_(vals == [2,4,6]) + finally: + cr.execute('DROP SEQUENCE busto') + + def test_dbapi_breaks_mod_binds(self): + con = testing.db.connect().connection + + cr = con.cursor() + # OK + cr.execute('SELECT MOD(3, 2) FROM DUAL') + + # Broken! + try: + cr.execute('SELECT MOD(3, ?) FROM DUAL', [2]) + self.assert_(False) + except: + self.assert_(True) + + # OK + cr.execute('SELECT MOD(?, 2) FROM DUAL', [3]) + + def test_dbapi_breaks_close(self): + dialect = testing.db.dialect + cargs, ckw = dialect.create_connect_args(testing.db.url) + + # There doesn't seem to be a way to test for this as it occurs in + # regular usage- the warning doesn't seem to go through 'warnings'. + con = dialect.dbapi.connect(*cargs, **ckw) + con.close() + del con # <-- exception during __del__ + + # But this does the same thing. + con = dialect.dbapi.connect(*cargs, **ckw) + self.assert_(con.close == con.__del__) + con.close() + try: + con.close() + self.assert_(False) + except dialect.dbapi.DatabaseError: + self.assert_(True) + + def test_modulo_operator(self): + st = str(select([sql.column('col') % 5]).compile(testing.db)) + self.assertEquals(st, 'SELECT mod(col, ?) FROM DUAL') + + +if __name__ == "__main__": + testenv.main() diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py new file mode 100755 index 0000000000..b5d7f1641b --- /dev/null +++ b/test/dialect/mssql.py @@ -0,0 +1,256 @@ +import testenv; testenv.configure_for_tests() +import re +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy import exceptions +from sqlalchemy.sql import table, column +from sqlalchemy.databases import mssql +from testlib import * + + +class CompileTest(TestBase, AssertsCompiledSQL): + __dialect__ = mssql.MSSQLDialect() + + def test_insert(self): + t = table('sometable', column('somecolumn')) + self.assert_compile(t.insert(), "INSERT INTO sometable (somecolumn) VALUES (:somecolumn)") + + def test_update(self): + t = table('sometable', column('somecolumn')) + self.assert_compile(t.update(t.c.somecolumn==7), "UPDATE sometable SET somecolumn=:somecolumn WHERE sometable.somecolumn = :somecolumn_1", dict(somecolumn=10)) + + def test_count(self): + t = table('sometable', column('somecolumn')) + self.assert_compile(t.count(), "SELECT count(sometable.somecolumn) AS tbl_row_count FROM sometable") + + def test_noorderby_insubquery(self): + """test that the ms-sql dialect removes ORDER BY clauses from subqueries""" + + table1 = table('mytable', + column('myid', Integer), + column('name', String), + column('description', String), + ) + + q = select([table1.c.myid], order_by=[table1.c.myid]).alias('foo') + crit = q.c.myid == table1.c.myid + self.assert_compile(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable) AS foo, mytable WHERE foo.myid = mytable.myid""") + + def test_aliases_schemas(self): + metadata = MetaData() + table1 = table('mytable', + column('myid', Integer), + column('name', String), + column('description', String), + ) + + table4 = Table( + 'remotetable', metadata, 
+ Column('rem_id', Integer, primary_key=True), + Column('datatype_id', Integer), + Column('value', String(20)), + schema = 'remote_owner' + ) + + s = table4.select() + c = s.compile(dialect=self.__dialect__) + assert table4.c.rem_id in set(c.result_map['rem_id'][1]) + + s = table4.select(use_labels=True) + c = s.compile(dialect=self.__dialect__) + print c.result_map + assert table4.c.rem_id in set(c.result_map['remote_owner_remotetable_rem_id'][1]) + + self.assert_compile(table4.select(), "SELECT remotetable_1.rem_id, remotetable_1.datatype_id, remotetable_1.value FROM remote_owner.remotetable AS remotetable_1") + + self.assert_compile(table4.select(use_labels=True), "SELECT remotetable_1.rem_id AS remote_owner_remotetable_rem_id, remotetable_1.datatype_id AS remote_owner_remotetable_datatype_id, remotetable_1.value AS remote_owner_remotetable_value FROM remote_owner.remotetable AS remotetable_1") + + self.assert_compile(table1.join(table4, table1.c.myid==table4.c.rem_id).select(), "SELECT mytable.myid, mytable.name, mytable.description, remotetable_1.rem_id, remotetable_1.datatype_id, remotetable_1.value FROM mytable JOIN remote_owner.remotetable AS remotetable_1 ON remotetable_1.rem_id = mytable.myid") + + def test_delete_schema(self): + metadata = MetaData() + tbl = Table('test', metadata, Column('id', Integer, primary_key=True), schema='paj') + self.assert_compile(tbl.delete(tbl.c.id == 1), "DELETE FROM paj.test WHERE paj.test.id = :id_1") + + def test_union(self): + t1 = table('t1', + column('col1'), + column('col2'), + column('col3'), + column('col4') + ) + t2 = table('t2', + column('col1'), + column('col2'), + column('col3'), + column('col4')) + + (s1, s2) = ( + select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], t1.c.col2.in_(["t1col2r1", "t1col2r2"])), + select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], t2.c.col2.in_(["t2col2r2", "t2col2r3"])) + ) + u = union(s1, s2, order_by=['col3', 'col4']) + self.assert_compile(u, "SELECT t1.col3 AS col3, t1.col4 AS col4 FROM t1 WHERE t1.col2 IN (:col2_1, :col2_2) "\ + "UNION SELECT t2.col3 AS col3, t2.col4 AS col4 FROM t2 WHERE t2.col2 IN (:col2_3, :col2_4) ORDER BY col3, col4") + + self.assert_compile(u.alias('bar').select(), "SELECT bar.col3, bar.col4 FROM (SELECT t1.col3 AS col3, t1.col4 AS col4 FROM t1 WHERE "\ + "t1.col2 IN (:col2_1, :col2_2) UNION SELECT t2.col3 AS col3, t2.col4 AS col4 FROM t2 WHERE t2.col2 IN (:col2_3, :col2_4)) AS bar") + + def test_function(self): + self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)") + self.assert_compile(func.current_time(), "CURRENT_TIME") + self.assert_compile(func.foo(), "foo()") + + m = MetaData() + t = Table('sometable', m, Column('col1', Integer), Column('col2', Integer)) + self.assert_compile(select([func.max(t.c.col1)]), "SELECT max(sometable.col1) AS max_1 FROM sometable") + +class ReflectionTest(TestBase): + __only_on__ = 'mssql' + + def testidentity(self): + meta = MetaData(testing.db) + table = Table( + 'identity_test', meta, + Column('col1', Integer, Sequence('fred', 2, 3), primary_key=True) + ) + table.create() + + meta2 = MetaData(testing.db) + try: + table2 = Table('identity_test', meta2, autoload=True) + assert table2.c['col1'].sequence.start == 2 + assert table2.c['col1'].sequence.increment == 3 + finally: + table.drop() + + +class QueryTest(TestBase): + __only_on__ = 'mssql' + + def test_fetchid_trigger(self): + meta = MetaData(testing.db) + t1 = Table('t1', meta, + Column('id', Integer, Sequence('fred', 100, 1), primary_key=True), + Column('descr', 
String(200))) + t2 = Table('t2', meta, + Column('id', Integer, Sequence('fred', 200, 1), primary_key=True), + Column('descr', String(200))) + meta.create_all() + con = testing.db.connect() + con.execute("""create trigger paj on t1 for insert as + insert into t2 (descr) select descr from inserted""") + + try: + tr = con.begin() + r = con.execute(t2.insert(), descr='hello') + self.assert_(r.last_inserted_ids() == [200]) + r = con.execute(t1.insert(), descr='hello') + self.assert_(r.last_inserted_ids() == [100]) + + finally: + tr.commit() + con.execute("""drop trigger paj""") + meta.drop_all() + + def test_insertid_schema(self): + meta = MetaData(testing.db) + con = testing.db.connect() + con.execute('create schema paj') + tbl = Table('test', meta, Column('id', Integer, primary_key=True), schema='paj') + tbl.create() + try: + tbl.insert().execute({'id':1}) + finally: + tbl.drop() + con.execute('drop schema paj') + + def test_delete_schema(self): + meta = MetaData(testing.db) + con = testing.db.connect() + con.execute('create schema paj') + tbl = Table('test', meta, Column('id', Integer, primary_key=True), schema='paj') + tbl.create() + try: + tbl.insert().execute({'id':1}) + tbl.delete(tbl.c.id == 1).execute() + finally: + tbl.drop() + con.execute('drop schema paj') + + def test_insertid_reserved(self): + meta = MetaData(testing.db) + table = Table( + 'select', meta, + Column('col', Integer, primary_key=True) + ) + table.create() + + meta2 = MetaData(testing.db) + try: + table.insert().execute(col=7) + finally: + table.drop() + + def test_select_limit_nooffset(self): + metadata = MetaData(testing.db) + + users = Table('query_users', metadata, + Column('user_id', INT, primary_key = True), + Column('user_name', VARCHAR(20)), + ) + addresses = Table('query_addresses', metadata, + Column('address_id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('query_users.user_id')), + Column('address', String(30))) + metadata.create_all() + + try: + try: + r = users.select(limit=3, offset=2, + order_by=[users.c.user_id]).execute().fetchall() + assert False # InvalidRequestError should have been raised + except exceptions.InvalidRequestError: + pass + finally: + metadata.drop_all() + +class Foo(object): + def __init__(self, **kw): + for k in kw: + setattr(self, k, kw[k]) + +class GenerativeQueryTest(TestBase): + __only_on__ = 'mssql' + + def setUpAll(self): + global foo, metadata + metadata = MetaData(testing.db) + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_id_seq'), + primary_key=True), + Column('bar', Integer), + Column('range', Integer)) + + mapper(Foo, foo) + metadata.create_all() + + sess = create_session(bind=testing.db) + for i in range(100): + sess.save(Foo(bar=i, range=i%10)) + sess.flush() + + def tearDownAll(self): + metadata.drop_all() + clear_mappers() + + def test_slice_mssql(self): + sess = create_session(bind=testing.db) + query = sess.query(Foo) + orig = query.all() + assert list(query[:10]) == orig[:10] + assert list(query[:10]) == orig[:10] + + +if __name__ == "__main__": + testenv.main() diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index dbba78893d..00478908ef 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -1,16 +1,52 @@ -import testbase +import testenv; testenv.configure_for_tests() +import sets from sqlalchemy import * +from sqlalchemy import sql, exceptions from sqlalchemy.databases import mysql from testlib import * -class TypesTest(AssertMixin): +class TypesTest(TestBase, AssertsExecutionResults): "Test 
MySQL column types" - @testing.supported('mysql') + __only_on__ = 'mysql' + + def test_basic(self): + meta1 = MetaData(testing.db) + table = Table( + 'mysql_types', meta1, + Column('id', Integer, primary_key=True), + Column('num1', mysql.MSInteger(unsigned=True)), + Column('text1', mysql.MSLongText), + Column('text2', mysql.MSLongText()), + Column('num2', mysql.MSBigInteger), + Column('num3', mysql.MSBigInteger()), + Column('num4', mysql.MSDouble), + Column('num5', mysql.MSDouble()), + Column('enum1', mysql.MSEnum("'black'", "'white'")), + ) + try: + table.drop(checkfirst=True) + table.create() + meta2 = MetaData(testing.db) + t2 = Table('mysql_types', meta2, autoload=True) + assert isinstance(t2.c.num1.type, mysql.MSInteger) + assert t2.c.num1.type.unsigned + assert isinstance(t2.c.text1.type, mysql.MSLongText) + assert isinstance(t2.c.text2.type, mysql.MSLongText) + assert isinstance(t2.c.num2.type, mysql.MSBigInteger) + assert isinstance(t2.c.num3.type, mysql.MSBigInteger) + assert isinstance(t2.c.num4.type, mysql.MSDouble) + assert isinstance(t2.c.num5.type, mysql.MSDouble) + assert isinstance(t2.c.enum1.type, mysql.MSEnum) + t2.drop() + t2.create() + finally: + meta1.drop_all() + def test_numeric(self): "Exercise type specification and options for numeric types." - + columns = [ # column type, args, kwargs, expected ddl # e.g. Column(Integer(10, unsigned=True)) == 'INTEGER(10) UNSIGNED' @@ -44,8 +80,6 @@ class TypesTest(AssertMixin): (mysql.MSDouble, [None, None], {}, 'DOUBLE'), - (mysql.MSDouble, [12], {}, - 'DOUBLE(12, 2)'), (mysql.MSDouble, [12, 4], {'unsigned':True}, 'DOUBLE(12, 4) UNSIGNED'), (mysql.MSDouble, [12, 4], {'zerofill':True}, @@ -53,8 +87,17 @@ class TypesTest(AssertMixin): (mysql.MSDouble, [12, 4], {'zerofill':True, 'unsigned':True}, 'DOUBLE(12, 4) UNSIGNED ZEROFILL'), + (mysql.MSReal, [None, None], {}, + 'REAL'), + (mysql.MSReal, [12, 4], {'unsigned':True}, + 'REAL(12, 4) UNSIGNED'), + (mysql.MSReal, [12, 4], {'zerofill':True}, + 'REAL(12, 4) ZEROFILL'), + (mysql.MSReal, [12, 4], {'zerofill':True, 'unsigned':True}, + 'REAL(12, 4) UNSIGNED ZEROFILL'), + (mysql.MSFloat, [], {}, - 'FLOAT(10)'), + 'FLOAT'), (mysql.MSFloat, [None], {}, 'FLOAT'), (mysql.MSFloat, [12], {}, @@ -90,6 +133,17 @@ class TypesTest(AssertMixin): (mysql.MSBigInteger, [4], {'zerofill':True, 'unsigned':True}, 'BIGINT(4) UNSIGNED ZEROFILL'), + (mysql.MSTinyInteger, [], {}, + 'TINYINT'), + (mysql.MSTinyInteger, [1], {}, + 'TINYINT(1)'), + (mysql.MSTinyInteger, [1], {'unsigned':True}, + 'TINYINT(1) UNSIGNED'), + (mysql.MSTinyInteger, [1], {'zerofill':True}, + 'TINYINT(1) ZEROFILL'), + (mysql.MSTinyInteger, [1], {'zerofill':True, 'unsigned':True}, + 'TINYINT(1) UNSIGNED ZEROFILL'), + (mysql.MSSmallInteger, [], {}, 'SMALLINT'), (mysql.MSSmallInteger, [4], {}, @@ -102,18 +156,19 @@ class TypesTest(AssertMixin): 'SMALLINT(4) UNSIGNED ZEROFILL'), ] - table_args = ['test_mysql_numeric', MetaData(testbase.db)] + table_args = ['test_mysql_numeric', MetaData(testing.db)] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw))) numeric_table = Table(*table_args) - gen = testbase.db.dialect.schemagenerator(testbase.db, None, None) - + gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + for col in numeric_table.c: index = int(col.name[1:]) - self.assertEquals(gen.get_column_specification(col), - "%s %s" % (col.name, columns[index][3])) + self.assert_eq(gen.get_column_specification(col), + "%s %s" % (col.name, 
columns[index][3])) + self.assert_(repr(col)) try: numeric_table.create(checkfirst=True) @@ -121,11 +176,10 @@ class TypesTest(AssertMixin): except: raise numeric_table.drop() - - @testing.supported('mysql') + + @testing.exclude('mysql', '<', (4, 1, 1)) def test_charset(self): - """Exercise CHARACTER SET and COLLATE-related options on string-type - columns.""" + """Exercise CHARACTER SET and COLLATE-ish options on string types.""" columns = [ (mysql.MSChar, [1], {}, @@ -177,7 +231,7 @@ class TypesTest(AssertMixin): 'TINYTEXT CHARACTER SET utf8 COLLATE utf8_bin'), (mysql.MSMediumText, [], {'charset':'utf8', 'binary':True}, - 'MEDIUMTEXT CHARACTER SET utf8 BINARY'), + 'MEDIUMTEXT CHARACTER SET utf8 BINARY'), (mysql.MSLongText, [], {'ascii':True}, 'LONGTEXT ASCII'), @@ -186,18 +240,19 @@ class TypesTest(AssertMixin): '''ENUM('foo','bar') UNICODE''') ] - table_args = ['test_mysql_charset', MetaData(testbase.db)] + table_args = ['test_mysql_charset', MetaData(testing.db)] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw))) charset_table = Table(*table_args) - gen = testbase.db.dialect.schemagenerator(testbase.db, None, None) - + gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + for col in charset_table.c: index = int(col.name[1:]) - self.assertEquals(gen.get_column_specification(col), - "%s %s" % (col.name, columns[index][3])) + self.assert_eq(gen.get_column_specification(col), + "%s %s" % (col.name, columns[index][3])) + self.assert_(repr(col)) try: charset_table.create(checkfirst=True) @@ -206,22 +261,276 @@ class TypesTest(AssertMixin): raise charset_table.drop() - @testing.supported('mysql') + @testing.exclude('mysql', '<', (5, 0, 5)) + def test_bit_50(self): + """Exercise BIT types on 5.0+ (not valid for all engine types)""" + + meta = MetaData(testing.db) + bit_table = Table('mysql_bits', meta, + Column('b1', mysql.MSBit), + Column('b2', mysql.MSBit()), + Column('b3', mysql.MSBit(), nullable=False), + Column('b4', mysql.MSBit(1)), + Column('b5', mysql.MSBit(8)), + Column('b6', mysql.MSBit(32)), + Column('b7', mysql.MSBit(63)), + Column('b8', mysql.MSBit(64))) + + self.assert_eq(colspec(bit_table.c.b1), 'b1 BIT') + self.assert_eq(colspec(bit_table.c.b2), 'b2 BIT') + self.assert_eq(colspec(bit_table.c.b3), 'b3 BIT NOT NULL') + self.assert_eq(colspec(bit_table.c.b4), 'b4 BIT(1)') + self.assert_eq(colspec(bit_table.c.b5), 'b5 BIT(8)') + self.assert_eq(colspec(bit_table.c.b6), 'b6 BIT(32)') + self.assert_eq(colspec(bit_table.c.b7), 'b7 BIT(63)') + self.assert_eq(colspec(bit_table.c.b8), 'b8 BIT(64)') + + for col in bit_table.c: + self.assert_(repr(col)) + try: + meta.create_all() + + meta2 = MetaData(testing.db) + reflected = Table('mysql_bits', meta2, autoload=True) + + for table in bit_table, reflected: + + def roundtrip(store, expected=None): + expected = expected or store + table.insert(store).execute() + row = list(table.select().execute())[0] + try: + self.assert_(list(row) == expected) + except: + print "Storing %s" % store + print "Expected %s" % expected + print "Found %s" % list(row) + raise + table.delete().execute() + + roundtrip([0] * 8) + roundtrip([None, None, 0, None, None, None, None, None]) + roundtrip([1] * 8) + roundtrip([sql.text("b'1'")] * 8, [1] * 8) + + i = 255 + roundtrip([0, 0, 0, 0, i, i, i, i]) + i = 2**32 - 1 + roundtrip([0, 0, 0, 0, 0, i, i, i]) + i = 2**63 - 1 + roundtrip([0, 0, 0, 0, 0, 0, i, i]) + i = 2**64 - 1 + roundtrip([0, 0, 0, 0, 0, 0, 0, i]) + 
finally: + meta.drop_all() + + def test_boolean(self): + """Test BOOL/TINYINT(1) compatability and reflection.""" + + meta = MetaData(testing.db) + bool_table = Table('mysql_bool', meta, + Column('b1', BOOLEAN), + Column('b2', mysql.MSBoolean), + Column('b3', mysql.MSTinyInteger(1)), + Column('b4', mysql.MSTinyInteger)) + + self.assert_eq(colspec(bool_table.c.b1), 'b1 BOOL') + self.assert_eq(colspec(bool_table.c.b2), 'b2 BOOL') + self.assert_eq(colspec(bool_table.c.b3), 'b3 TINYINT(1)') + self.assert_eq(colspec(bool_table.c.b4), 'b4 TINYINT') + + for col in bool_table.c: + self.assert_(repr(col)) + try: + meta.create_all() + + table = bool_table + def roundtrip(store, expected=None): + expected = expected or store + table.insert(store).execute() + row = list(table.select().execute())[0] + try: + self.assert_(list(row) == expected) + for i, val in enumerate(expected): + if isinstance(val, bool): + self.assert_(val is row[i]) + except: + print "Storing %s" % store + print "Expected %s" % expected + print "Found %s" % list(row) + raise + table.delete().execute() + + + roundtrip([None, None, None, None]) + roundtrip([True, True, 1, 1]) + roundtrip([False, False, 0, 0]) + roundtrip([True, True, True, True], [True, True, 1, 1]) + roundtrip([False, False, 0, 0], [False, False, 0, 0]) + + meta2 = MetaData(testing.db) + # replace with reflected + table = Table('mysql_bool', meta2, autoload=True) + self.assert_eq(colspec(table.c.b3), 'b3 BOOL') + + roundtrip([None, None, None, None]) + roundtrip([True, True, 1, 1], [True, True, True, 1]) + roundtrip([False, False, 0, 0], [False, False, False, 0]) + roundtrip([True, True, True, True], [True, True, True, 1]) + roundtrip([False, False, 0, 0], [False, False, False, 0]) + finally: + meta.drop_all() + + @testing.exclude('mysql', '<', (4, 1, 0)) + def test_timestamp(self): + """Exercise funky TIMESTAMP default syntax.""" + + meta = MetaData(testing.db) + + try: + columns = [ + ([TIMESTAMP], + 'TIMESTAMP'), + ([mysql.MSTimeStamp], + 'TIMESTAMP'), + ([mysql.MSTimeStamp, + PassiveDefault(sql.text('CURRENT_TIMESTAMP'))], + "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), + ([mysql.MSTimeStamp, + PassiveDefault(sql.text("'1999-09-09 09:09:09'"))], + "TIMESTAMP DEFAULT '1999-09-09 09:09:09'"), + ([mysql.MSTimeStamp, + PassiveDefault(sql.text("'1999-09-09 09:09:09' " + "ON UPDATE CURRENT_TIMESTAMP"))], + "TIMESTAMP DEFAULT '1999-09-09 09:09:09' " + "ON UPDATE CURRENT_TIMESTAMP"), + ([mysql.MSTimeStamp, + PassiveDefault(sql.text("CURRENT_TIMESTAMP " + "ON UPDATE CURRENT_TIMESTAMP"))], + "TIMESTAMP DEFAULT CURRENT_TIMESTAMP " + "ON UPDATE CURRENT_TIMESTAMP"), + ] + for idx, (spec, expected) in enumerate(columns): + t = Table('mysql_ts%s' % idx, meta, + Column('id', Integer, primary_key=True), + Column('t', *spec)) + self.assert_eq(colspec(t.c.t), "t %s" % expected) + self.assert_(repr(t.c.t)) + t.create() + r = Table('mysql_ts%s' % idx, MetaData(testing.db), + autoload=True) + if len(spec) > 1: + self.assert_(r.c.t is not None) + finally: + meta.drop_all() + + def test_year(self): + """Exercise YEAR.""" + + meta = MetaData(testing.db) + year_table = Table('mysql_year', meta, + Column('y1', mysql.MSYear), + Column('y2', mysql.MSYear), + Column('y3', mysql.MSYear), + Column('y4', mysql.MSYear(2)), + Column('y5', mysql.MSYear(4))) + + for col in year_table.c: + self.assert_(repr(col)) + try: + year_table.create() + reflected = Table('mysql_year', MetaData(testing.db), + autoload=True) + + for table in year_table, reflected: + table.insert(['1950', '50', None, 50, 
1950]).execute() + row = list(table.select().execute())[0] + self.assert_eq(list(row), [1950, 2050, None, 50, 1950]) + table.delete().execute() + self.assert_(colspec(table.c.y1).startswith('y1 YEAR')) + self.assert_eq(colspec(table.c.y4), 'y4 YEAR(2)') + self.assert_eq(colspec(table.c.y5), 'y5 YEAR(4)') + finally: + meta.drop_all() + + + def test_set(self): + """Exercise the SET type.""" + + meta = MetaData(testing.db) + set_table = Table('mysql_set', meta, + Column('s1', mysql.MSSet("'dq'", "'sq'")), + Column('s2', mysql.MSSet("'a'")), + Column('s3', mysql.MSSet("'5'", "'7'", "'9'"))) + + self.assert_eq(colspec(set_table.c.s1), "s1 SET('dq','sq')") + self.assert_eq(colspec(set_table.c.s2), "s2 SET('a')") + self.assert_eq(colspec(set_table.c.s3), "s3 SET('5','7','9')") + + for col in set_table.c: + self.assert_(repr(col)) + try: + set_table.create() + reflected = Table('mysql_set', MetaData(testing.db), + autoload=True) + + for table in set_table, reflected: + def roundtrip(store, expected=None): + expected = expected or store + table.insert(store).execute() + row = list(table.select().execute())[0] + try: + self.assert_(list(row) == expected) + except: + print "Storing %s" % store + print "Expected %s" % expected + print "Found %s" % list(row) + raise + table.delete().execute() + + roundtrip([None, None, None],[None] * 3) + roundtrip(['', '', ''], [set([''])] * 3) + + roundtrip([set(['dq']), set(['a']), set(['5'])]) + roundtrip(['dq', 'a', '5'], + [set(['dq']), set(['a']), set(['5'])]) + roundtrip([1, 1, 1], + [set(['dq']), set(['a']), set(['5'])]) + roundtrip([set(['dq', 'sq']), None, set(['9', '5', '7'])]) + + set_table.insert().execute({'s3':set(['5'])}, + {'s3':set(['5', '7'])}, + {'s3':set(['5', '7', '9'])}, + {'s3':set(['7', '9'])}) + rows = list(select( + [set_table.c.s3], + set_table.c.s3.in_([set(['5']), set(['5', '7'])])).execute()) + found = set([frozenset(row[0]) for row in rows]) + self.assertEquals(found, + set([frozenset(['5']), frozenset(['5', '7'])])) + finally: + meta.drop_all() + def test_enum(self): - "Exercise the ENUM type" - - db = testbase.db - enum_table = Table('mysql_enum', MetaData(testbase.db), - Column('e1', mysql.MSEnum('"a"', "'b'")), - Column('e2', mysql.MSEnum('"a"', "'b'"), nullable=False), - Column('e3', mysql.MSEnum('"a"', "'b'", strict=True)), - Column('e4', mysql.MSEnum('"a"', "'b'", strict=True), nullable=False)) - spec = lambda c: db.dialect.schemagenerator(db, None, None).get_column_specification(c) - - self.assertEqual(spec(enum_table.c.e1), """e1 ENUM("a",'b')""") - self.assertEqual(spec(enum_table.c.e2), """e2 ENUM("a",'b') NOT NULL""") - self.assertEqual(spec(enum_table.c.e3), """e3 ENUM("a",'b')""") - self.assertEqual(spec(enum_table.c.e4), """e4 ENUM("a",'b') NOT NULL""") + """Exercise the ENUM type.""" + + db = testing.db + enum_table = Table('mysql_enum', MetaData(testing.db), + Column('e1', mysql.MSEnum("'a'", "'b'")), + Column('e2', mysql.MSEnum("'a'", "'b'"), + nullable=False), + Column('e3', mysql.MSEnum("'a'", "'b'", strict=True)), + Column('e4', mysql.MSEnum("'a'", "'b'", strict=True), + nullable=False)) + + self.assert_eq(colspec(enum_table.c.e1), + "e1 ENUM('a','b')") + self.assert_eq(colspec(enum_table.c.e2), + "e2 ENUM('a','b') NOT NULL") + self.assert_eq(colspec(enum_table.c.e3), + "e3 ENUM('a','b')") + self.assert_eq(colspec(enum_table.c.e4), + "e4 ENUM('a','b') NOT NULL") enum_table.drop(checkfirst=True) enum_table.create() @@ -250,10 +559,10 @@ class TypesTest(AssertMixin): # This is known to fail with MySQLDB 1.2.2 beta 
versions # which return these as sets.Set(['a']), sets.Set(['b']) # (even on Pythons with __builtin__.set) - if testbase.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \ - testbase.db.dialect.dbapi.version_info >= (1, 2, 2): + if testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \ + testing.db.dialect.dbapi.version_info >= (1, 2, 2): # these mysqldb seem to always uses 'sets', even on later pythons - import sets + import sets def convert(value): if value is None: return value @@ -261,32 +570,81 @@ class TypesTest(AssertMixin): return sets.Set([]) else: return sets.Set([value]) - + e = [] for row in expected: e.append(tuple([convert(c) for c in row])) expected = e - self.assertEqual(res, expected) + self.assert_eq(res, expected) enum_table.drop() - @testing.supported('mysql') - def test_type_reflection(self): - # FIXME: older versions need their own test - if testbase.db.dialect.get_version_info(testbase.db) < (5, 0): - return + @testing.exclude('mysql', '>', (3)) + def test_enum_parse(self): + """More exercises for the ENUM type.""" + + # MySQL 3.23 can't handle an ENUM of ''.... + enum_table = Table('mysql_enum', MetaData(testing.db), + Column('e1', mysql.MSEnum("'a'")), + Column('e2', mysql.MSEnum("''")), + Column('e3', mysql.MSEnum("'a'", "''")), + Column('e4', mysql.MSEnum("''", "'a'")), + Column('e5', mysql.MSEnum("''", "'''a'''", "'b''b'", "''''"))) + + for col in enum_table.c: + self.assert_(repr(col)) + try: + enum_table.create() + reflected = Table('mysql_enum', MetaData(testing.db), + autoload=True) + for t in enum_table, reflected: + assert t.c.e1.type.enums == ["a"] + assert t.c.e2.type.enums == [""] + assert t.c.e3.type.enums == ["a", ""] + assert t.c.e4.type.enums == ["", "a"] + assert t.c.e5.type.enums == ["", "'a'", "b'b", "'"] + finally: + enum_table.drop() + + def test_default_reflection(self): + """Test reflection of column defaults.""" + + def_table = Table('mysql_def', MetaData(testing.db), + Column('c1', String(10), PassiveDefault('')), + Column('c2', String(10), PassiveDefault('0')), + Column('c3', String(10), PassiveDefault('abc'))) + + try: + def_table.create() + reflected = Table('mysql_def', MetaData(testing.db), + autoload=True) + for t in def_table, reflected: + assert t.c.c1.default.arg == '' + assert t.c.c2.default.arg == '0' + assert t.c.c3.default.arg == 'abc' + finally: + def_table.drop() + + @testing.exclude('mysql', '<', (5, 0, 0)) + @testing.uses_deprecated('Using String type with no length') + def test_type_reflection(self): # (ask_for, roundtripped_as_if_different) specs = [( String(), mysql.MSText(), ), ( String(1), mysql.MSString(1), ), ( String(3), mysql.MSString(3), ), + ( Text(), mysql.MSText(), ), + ( Unicode(), mysql.MSText(), ), + ( Unicode(1), mysql.MSString(1), ), + ( Unicode(3), mysql.MSString(3), ), + ( UnicodeText(), mysql.MSText(), ), ( mysql.MSChar(1), ), ( mysql.MSChar(3), ), ( NCHAR(2), mysql.MSChar(2), ), ( mysql.MSNChar(2), mysql.MSChar(2), ), # N is CREATE only ( mysql.MSNVarChar(22), mysql.MSString(22), ), - ( Smallinteger(), mysql.MSSmallInteger(), ), - ( Smallinteger(4), mysql.MSSmallInteger(4), ), + ( SmallInteger(), mysql.MSSmallInteger(), ), + ( SmallInteger(4), mysql.MSSmallInteger(4), ), ( mysql.MSSmallInteger(), ), ( mysql.MSSmallInteger(4), mysql.MSSmallInteger(4), ), ( Binary(3), mysql.MSBlob(3), ), @@ -299,26 +657,322 @@ class TypesTest(AssertMixin): ( mysql.MSBlob(1234), mysql.MSBlob()), ( mysql.MSMediumBlob(),), ( mysql.MSLongBlob(),), + ( mysql.MSEnum("''","'fleem'"), ), ] columns = 
[Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)] - m = MetaData(testbase.db) + db = testing.db + m = MetaData(db) t_table = Table('mysql_types', m, *columns) - m.drop_all() - m.create_all() - - m2 = MetaData(testbase.db) - rt = Table('mysql_types', m2, autoload=True) - - #print - expected = [len(c) > 1 and c[1] or c[0] for c in specs] - for i, reflected in enumerate(rt.c): - #print (reflected, specs[i][0], '->', - # reflected.type, '==', expected[i]) - assert isinstance(reflected.type, type(expected[i])) - - m.drop_all() + try: + m.create_all() + + m2 = MetaData(db) + rt = Table('mysql_types', m2, autoload=True) + try: + db.execute('CREATE OR REPLACE VIEW mysql_types_v ' + 'AS SELECT * from mysql_types') + rv = Table('mysql_types_v', m2, autoload=True) + + expected = [len(c) > 1 and c[1] or c[0] for c in specs] + + # Early 5.0 releases seem to report more "general" for columns + # in a view, e.g. char -> varchar, tinyblob -> mediumblob + # + # Not sure exactly which point version has the fix. + if db.dialect.server_version_info(db.connect()) < (5, 0, 11): + tables = rt, + else: + tables = rt, rv + + for table in tables: + for i, reflected in enumerate(table.c): + assert isinstance(reflected.type, type(expected[i])) + finally: + db.execute('DROP VIEW mysql_types_v') + finally: + m.drop_all() + + def test_autoincrement(self): + meta = MetaData(testing.db) + try: + Table('ai_1', meta, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True)) + Table('ai_2', meta, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True)) + Table('ai_3', meta, + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True, autoincrement=False), + Column('int_y', Integer, primary_key=True)) + Table('ai_4', meta, + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True, autoincrement=False), + Column('int_n2', Integer, PassiveDefault('0'), + primary_key=True, autoincrement=False)) + Table('ai_5', meta, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, PassiveDefault('0'), + primary_key=True, autoincrement=False)) + Table('ai_6', meta, + Column('o1', String(1), PassiveDefault('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_7', meta, + Column('o1', String(1), PassiveDefault('x'), + primary_key=True), + Column('o2', String(1), PassiveDefault('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_8', meta, + Column('o1', String(1), PassiveDefault('x'), + primary_key=True), + Column('o2', String(1), PassiveDefault('x'), + primary_key=True)) + meta.create_all() + + table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', + 'ai_5', 'ai_6', 'ai_7', 'ai_8'] + mr = MetaData(testing.db) + mr.reflect(only=table_names) + + for tbl in [mr.tables[name] for name in table_names]: + for c in tbl.c: + if c.name.startswith('int_y'): + assert c.autoincrement + elif c.name.startswith('int_n'): + assert not c.autoincrement + tbl.insert().execute() + if 'int_y' in tbl.c: + assert select([tbl.c.int_y]).scalar() == 1 + assert list(tbl.select().execute().fetchone()).count(1) == 1 + else: + assert 1 not in list(tbl.select().execute().fetchone()) + finally: + meta.drop_all() + + def assert_eq(self, got, wanted): + if got != wanted: + print "Expected %s" % wanted + print "Found %s" % got + self.assertEqual(got, wanted) + + +class SQLTest(TestBase, AssertsCompiledSQL): + """Tests MySQL-dialect specific compilation.""" + + 
__dialect__ = mysql.dialect() + + def test_precolumns(self): + dialect = self.__dialect__ + + def gen(distinct=None, prefixes=None): + kw = {} + if distinct is not None: + kw['distinct'] = distinct + if prefixes is not None: + kw['prefixes'] = prefixes + return str(select(['q'], **kw).compile(dialect=dialect)) + + self.assertEqual(gen(None), 'SELECT q') + self.assertEqual(gen(True), 'SELECT DISTINCT q') + self.assertEqual(gen(1), 'SELECT DISTINCT q') + self.assertEqual(gen('diSTInct'), 'SELECT DISTINCT q') + self.assertEqual(gen('DISTINCT'), 'SELECT DISTINCT q') + + # Standard SQL + self.assertEqual(gen('all'), 'SELECT ALL q') + self.assertEqual(gen('distinctrow'), 'SELECT DISTINCTROW q') + + # Interaction with MySQL prefix extensions + self.assertEqual( + gen(None, ['straight_join']), + 'SELECT straight_join q') + self.assertEqual( + gen('all', ['HIGH_PRIORITY SQL_SMALL_RESULT']), + 'SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL q') + self.assertEqual( + gen(True, ['high_priority', sql.text('sql_cache')]), + 'SELECT high_priority sql_cache DISTINCT q') + + def test_limit(self): + t = sql.table('t', sql.column('col1'), sql.column('col2')) + + self.assert_compile( + select([t]).limit(10).offset(20), + "SELECT t.col1, t.col2 FROM t LIMIT 20, 10" + ) + self.assert_compile( + select([t]).limit(10), + "SELECT t.col1, t.col2 FROM t LIMIT 10") + self.assert_compile( + select([t]).offset(10), + "SELECT t.col1, t.col2 FROM t LIMIT 10, 18446744073709551615" + ) + + def test_update_limit(self): + t = sql.table('t', sql.column('col1'), sql.column('col2')) + + self.assert_compile( + t.update(values={'col1':123}), + "UPDATE t SET col1=%s" + ) + self.assert_compile( + t.update(values={'col1':123}, mysql_limit=5), + "UPDATE t SET col1=%s LIMIT 5" + ) + self.assert_compile( + t.update(values={'col1':123}, mysql_limit=None), + "UPDATE t SET col1=%s" + ) + self.assert_compile( + t.update(t.c.col2==456, values={'col1':123}, mysql_limit=1), + "UPDATE t SET col1=%s WHERE t.col2 = %s LIMIT 1" + ) + + def test_cast(self): + t = sql.table('t', sql.column('col')) + m = mysql + + specs = [ + (Integer, "CAST(t.col AS SIGNED INTEGER)"), + (INT, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSInteger, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"), + (SmallInteger, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSSmallInteger, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSTinyInteger, "CAST(t.col AS SIGNED INTEGER)"), + # 'SIGNED INTEGER' is a bigint, so this is ok. + (m.MSBigInteger, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSBigInteger(unsigned=False), "CAST(t.col AS SIGNED INTEGER)"), + (m.MSBigInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"), + (m.MSBit, "t.col"), + + # this is kind of sucky. thank you default arguments! 
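            # (apparently those implicit defaults are precision=10, scale=2, which
            # is why the unparameterized NUMERIC/DECIMAL/Numeric entries below all
            # expect DECIMAL(10, 2))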
+ (NUMERIC, "CAST(t.col AS DECIMAL(10, 2))"), + (DECIMAL, "CAST(t.col AS DECIMAL(10, 2))"), + (Numeric, "CAST(t.col AS DECIMAL(10, 2))"), + (m.MSNumeric, "CAST(t.col AS DECIMAL(10, 2))"), + (m.MSDecimal, "CAST(t.col AS DECIMAL(10, 2))"), + + (FLOAT, "t.col"), + (Float, "t.col"), + (m.MSFloat, "t.col"), + (m.MSDouble, "t.col"), + (m.MSReal, "t.col"), + + (TIMESTAMP, "CAST(t.col AS DATETIME)"), + (DATETIME, "CAST(t.col AS DATETIME)"), + (DATE, "CAST(t.col AS DATE)"), + (TIME, "CAST(t.col AS TIME)"), + (DateTime, "CAST(t.col AS DATETIME)"), + (Date, "CAST(t.col AS DATE)"), + (Time, "CAST(t.col AS TIME)"), + (m.MSDateTime, "CAST(t.col AS DATETIME)"), + (m.MSDate, "CAST(t.col AS DATE)"), + (m.MSTime, "CAST(t.col AS TIME)"), + (m.MSTimeStamp, "CAST(t.col AS DATETIME)"), + (m.MSYear, "t.col"), + (m.MSYear(2), "t.col"), + (Interval, "t.col"), + + (String, "CAST(t.col AS CHAR)"), + (Unicode, "CAST(t.col AS CHAR)"), + (UnicodeText, "CAST(t.col AS CHAR)"), + (VARCHAR, "CAST(t.col AS CHAR)"), + (NCHAR, "CAST(t.col AS CHAR)"), + (CHAR, "CAST(t.col AS CHAR)"), + (CLOB, "CAST(t.col AS CHAR)"), + (TEXT, "CAST(t.col AS CHAR)"), + (String(32), "CAST(t.col AS CHAR(32))"), + (Unicode(32), "CAST(t.col AS CHAR(32))"), + (CHAR(32), "CAST(t.col AS CHAR(32))"), + (m.MSString, "CAST(t.col AS CHAR)"), + (m.MSText, "CAST(t.col AS CHAR)"), + (m.MSTinyText, "CAST(t.col AS CHAR)"), + (m.MSMediumText, "CAST(t.col AS CHAR)"), + (m.MSLongText, "CAST(t.col AS CHAR)"), + (m.MSNChar, "CAST(t.col AS CHAR)"), + (m.MSNVarChar, "CAST(t.col AS CHAR)"), + + (Binary, "CAST(t.col AS BINARY)"), + (BLOB, "CAST(t.col AS BINARY)"), + (m.MSBlob, "CAST(t.col AS BINARY)"), + (m.MSBlob(32), "CAST(t.col AS BINARY)"), + (m.MSTinyBlob, "CAST(t.col AS BINARY)"), + (m.MSMediumBlob, "CAST(t.col AS BINARY)"), + (m.MSLongBlob, "CAST(t.col AS BINARY)"), + (m.MSBinary, "CAST(t.col AS BINARY)"), + (m.MSBinary(32), "CAST(t.col AS BINARY)"), + (m.MSVarBinary, "CAST(t.col AS BINARY)"), + (m.MSVarBinary(32), "CAST(t.col AS BINARY)"), + + # maybe this could be changed to something more DWIM, needs + # testing + (Boolean, "t.col"), + (BOOLEAN, "t.col"), + (m.MSBoolean, "t.col"), + + (m.MSEnum, "t.col"), + (m.MSEnum("'1'", "'2'"), "t.col"), + (m.MSSet, "t.col"), + (m.MSSet("'1'", "'2'"), "t.col"), + ] + + for type_, expected in specs: + self.assert_compile(cast(t.c.col, type_), expected) + + +class ExecutionTest(TestBase): + """Various MySQL execution special cases.""" + + __only_on__ = 'mysql' + + def test_charset_caching(self): + engine = engines.testing_engine() + + cx = engine.connect() + meta = MetaData() + + assert ('mysql', 'charset') not in cx.info + assert ('mysql', 'force_charset') not in cx.info + + cx.execute(text("SELECT 1")).fetchall() + assert ('mysql', 'charset') not in cx.info + + meta.reflect(cx) + assert ('mysql', 'charset') in cx.info + + cx.execute(text("SET @squiznart=123")) + assert ('mysql', 'charset') in cx.info + + # the charset invalidation is very conservative + cx.execute(text("SET TIMESTAMP = DEFAULT")) + assert ('mysql', 'charset') not in cx.info + + cx.info[('mysql', 'force_charset')] = 'latin1' + + assert engine.dialect._detect_charset(cx) == 'latin1' + assert cx.info[('mysql', 'charset')] == 'latin1' + + del cx.info[('mysql', 'force_charset')] + del cx.info[('mysql', 'charset')] + + meta.reflect(cx) + assert ('mysql', 'charset') in cx.info + + # String execution doesn't go through the detector. 
+ cx.execute("SET TIMESTAMP = DEFAULT") + assert ('mysql', 'charset') in cx.info + + +def colspec(c): + return testing.db.dialect.schemagenerator(testing.db.dialect, + testing.db, None, None).get_column_specification(c) if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/dialect/oracle.py b/test/dialect/oracle.py index 14de8960b4..cdd575dd38 100644 --- a/test/dialect/oracle.py +++ b/test/dialect/oracle.py @@ -1,14 +1,15 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * -from sqlalchemy.databases import mysql - +from sqlalchemy.sql import table, column +from sqlalchemy.databases import oracle from testlib import * -class OutParamTest(AssertMixin): - @testing.supported('oracle') +class OutParamTest(TestBase, AssertsExecutionResults): + __only_on__ = 'oracle' + def setUpAll(self): - testbase.db.execute(""" + testing.db.execute(""" create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT number) IS retval number; begin @@ -18,15 +19,265 @@ create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT numb end; """) - @testing.supported('oracle') def test_out_params(self): - result = testbase.db.execute(text("begin foo(:x, :y, :z); end;", bindparams=[bindparam('x', Numeric), outparam('y', Numeric), outparam('z', Numeric)]), x=5) + result = testing.db.execute(text("begin foo(:x, :y, :z); end;", bindparams=[bindparam('x', Numeric), outparam('y', Numeric), outparam('z', Numeric)]), x=5) assert result.out_parameters == {'y':10, 'z':75}, result.out_parameters print result.out_parameters - @testing.supported('oracle') def tearDownAll(self): - testbase.db.execute("DROP PROCEDURE foo") + testing.db.execute("DROP PROCEDURE foo") + + +class CompileTest(TestBase, AssertsCompiledSQL): + __dialect__ = oracle.OracleDialect() + + def test_owner(self): + meta = MetaData() + parent = Table('parent', meta, Column('id', Integer, primary_key=True), + Column('name', String(50)), + owner='ed') + child = Table('child', meta, Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('ed.parent.id')), + owner = 'ed') + + self.assert_compile(parent.join(child), "ed.parent JOIN ed.child ON ed.parent.id = ed.child.parent_id") + + def test_subquery(self): + t = table('sometable', column('col1'), column('col2')) + s = select([t]) + s = select([s.c.col1, s.c.col2]) + + self.assert_compile(s, "SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, sometable.col2 AS col2 FROM sometable)") + + def test_limit(self): + t = table('sometable', column('col1'), column('col2')) + + s = select([t]) + c = s.compile(dialect=oracle.OracleDialect()) + assert t.c.col1 in set(c.result_map['col1'][1]) + + s = select([t]).limit(10).offset(20) + + self.assert_compile(s, "SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, sometable.col2 AS col2, " + "ROW_NUMBER() OVER (ORDER BY sometable.rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30" + ) + + # assert that despite the subquery, the columns from the table, + # not the select, get put into the "result_map" + c = s.compile(dialect=oracle.OracleDialect()) + assert t.c.col1 in set(c.result_map['col1'][1]) + + s = select([s.c.col1, s.c.col2]) + + self.assert_compile(s, "SELECT col1, col2 FROM (SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, " + "sometable.col2 AS col2, ROW_NUMBER() OVER (ORDER BY sometable.rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30)") + + # testing this twice to ensure oracle doesn't modify the original 
statement + self.assert_compile(s, "SELECT col1, col2 FROM (SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, " + "sometable.col2 AS col2, ROW_NUMBER() OVER (ORDER BY sometable.rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30)") + + s = select([t]).limit(10).offset(20).order_by(t.c.col2) + + self.assert_compile(s, "SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, " + "sometable.col2 AS col2, ROW_NUMBER() OVER (ORDER BY sometable.col2) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30") + + def test_outer_join(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String), + column('description', String), + ) + + table2 = table( + 'myothertable', + column('otherid', Integer), + column('othername', String), + ) + + table3 = table( + 'thirdtable', + column('userid', Integer), + column('otherstuff', String), + ) + + query = select( + [table1, table2], + or_( + table1.c.name == 'fred', + table1.c.myid == 10, + table2.c.othername != 'jack', + "EXISTS (select yay from foo where boo = lar)" + ), + from_obj = [ outerjoin(table1, table2, table1.c.myid == table2.c.otherid) ] + ) + self.assert_compile(query, + "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ +FROM mytable, myothertable WHERE \ +(mytable.name = :name_1 OR mytable.myid = :myid_1 OR \ +myothertable.othername != :othername_1 OR EXISTS (select yay from foo where boo = lar)) \ +AND mytable.myid = myothertable.otherid(+)", + dialect=oracle.OracleDialect(use_ansi = False)) + + query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid) + self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON thirdtable.userid = myothertable.otherid") + self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False)) + + query = table1.join(table2, table1.c.myid==table2.c.otherid).join(table3, table3.c.userid==table2.c.otherid) + self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid = myothertable.otherid", dialect=oracle.dialect(use_ansi=False)) + + query = table1.join(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid) + self.assert_compile(query.select().order_by(table1.oid_column).limit(10).offset(5), "SELECT myid, name, description, otherid, othername, userid, \ +otherstuff FROM (SELECT mytable.myid AS myid, mytable.name AS name, \ +mytable.description AS description, myothertable.otherid AS otherid, \ +myothertable.othername AS othername, thirdtable.userid AS userid, \ +thirdtable.otherstuff AS otherstuff, ROW_NUMBER() OVER (ORDER BY mytable.rowid) AS ora_rn \ +FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid(+) = myothertable.otherid) \ +WHERE ora_rn>5 
AND ora_rn<=15", dialect=oracle.dialect(use_ansi=False)) + + def test_alias_outer_join(self): + address_types = table('address_types', + column('id'), + column('name'), + ) + addresses = table('addresses', + column('id'), + column('user_id'), + column('address_type_id'), + column('email_address') + ) + at_alias = address_types.alias() + + s = select([at_alias, addresses]).\ + select_from(addresses.outerjoin(at_alias, addresses.c.address_type_id==at_alias.c.id)).\ + where(addresses.c.user_id==7).\ + order_by(addresses.oid_column, address_types.oid_column) + self.assert_compile(s, "SELECT address_types_1.id, address_types_1.name, addresses.id, addresses.user_id, " + "addresses.address_type_id, addresses.email_address FROM addresses LEFT OUTER JOIN address_types address_types_1 " + "ON addresses.address_type_id = address_types_1.id WHERE addresses.user_id = :user_id_1 ORDER BY addresses.rowid, " + "address_types.rowid") + +class SchemaReflectionTest(TestBase, AssertsCompiledSQL): + """instructions: + + 1. create a user 'ed' in the oracle database. + 2. in 'ed', issue the following statements: + create table parent(id integer primary key, data varchar2(50)); + create table child(id integer primary key, data varchar2(50), parent_id integer references parent(id)); + create synonym ptable for parent; + create synonym ctable for child; + grant all on parent to scott; (or to whoever you run the oracle tests as) + grant all on child to scott; (same) + grant all on ptable to scott; + grant all on ctable to scott; + + """ + + __only_on__ = 'oracle' + + def test_reflect_alt_owner_explicit(self): + meta = MetaData(testing.db) + parent = Table('parent', meta, autoload=True, schema='ed') + child = Table('child', meta, autoload=True, schema='ed') + + self.assert_compile(parent.join(child), "ed.parent JOIN ed.child ON ed.parent.id = ed.child.parent_id") + select([parent, child]).select_from(parent.join(child)).execute().fetchall() + + def test_reflect_local_to_remote(self): + testing.db.execute("CREATE TABLE localtable (id INTEGER PRIMARY KEY, parent_id INTEGER REFERENCES ed.parent(id))") + try: + meta = MetaData(testing.db) + lcl = Table('localtable', meta, autoload=True) + parent = meta.tables['ed.parent'] + self.assert_compile(parent.join(lcl), "ed.parent JOIN localtable ON ed.parent.id = localtable.parent_id") + select([parent, lcl]).select_from(parent.join(lcl)).execute().fetchall() + finally: + testing.db.execute("DROP TABLE localtable") + + def test_reflect_alt_owner_implicit(self): + meta = MetaData(testing.db) + parent = Table('parent', meta, autoload=True, schema='ed') + child = Table('child', meta, autoload=True, schema='ed') + + self.assert_compile(parent.join(child), "ed.parent JOIN ed.child ON ed.parent.id = ed.child.parent_id") + select([parent, child]).select_from(parent.join(child)).execute().fetchall() + + def test_reflect_alt_owner_synonyms(self): + testing.db.execute("CREATE TABLE localtable (id INTEGER PRIMARY KEY, parent_id INTEGER REFERENCES ed.ptable(id))") + try: + meta = MetaData(testing.db) + lcl = Table('localtable', meta, autoload=True, oracle_resolve_synonyms=True) + parent = meta.tables['ed.ptable'] + self.assert_compile(parent.join(lcl), "ed.ptable JOIN localtable ON ed.ptable.id = localtable.parent_id") + select([parent, lcl]).select_from(parent.join(lcl)).execute().fetchall() + finally: + testing.db.execute("DROP TABLE localtable") + + def test_reflect_remote_synonyms(self): + meta = MetaData(testing.db) + parent = Table('ptable', meta, autoload=True, schema='ed', 
oracle_resolve_synonyms=True) + child = Table('ctable', meta, autoload=True, schema='ed', oracle_resolve_synonyms=True) + self.assert_compile(parent.join(child), "ed.ptable JOIN ed.ctable ON ed.ptable.id = ed.ctable.parent_id") + select([parent, child]).select_from(parent.join(child)).execute().fetchall() + + +class TypesTest(TestBase, AssertsCompiledSQL): + __only_on__ = 'oracle' + + def test_no_clobs_for_string_params(self): + """test that simple string params get a DBAPI type of VARCHAR, not CLOB. + this is to prevent setinputsizes from setting up cx_oracle.CLOBs on + string-based bind params [ticket:793].""" + + class FakeDBAPI(object): + def __getattr__(self, attr): + return attr + + dialect = oracle.OracleDialect() + dbapi = FakeDBAPI() + + b = bindparam("foo", "hello world!") + assert b.type.dialect_impl(dialect).get_dbapi_type(dbapi) == 'STRING' + + b = bindparam("foo", u"hello world!") + assert b.type.dialect_impl(dialect).get_dbapi_type(dbapi) == 'STRING' + + def test_reflect_raw(self): + types_table = Table( + 'all_types', MetaData(testing.db), + Column('owner', String(30), primary_key=True), + Column('type_name', String(30), primary_key=True), + autoload=True, oracle_resolve_synonyms=True + ) + [[row[k] for k in row.keys()] for row in types_table.select().execute().fetchall()] + + def test_longstring(self): + metadata = MetaData(testing.db) + testing.db.execute(""" + CREATE TABLE Z_TEST + ( + ID NUMERIC(22) PRIMARY KEY, + ADD_USER VARCHAR2(20) NOT NULL + ) + """) + try: + t = Table("z_test", metadata, autoload=True) + t.insert().execute(id=1.0, add_user='foobar') + assert t.select().execute().fetchall() == [(1, 'foobar')] + finally: + testing.db.execute("DROP TABLE Z_TEST") + +class SequenceTest(TestBase, AssertsCompiledSQL): + def test_basic(self): + seq = Sequence("my_seq_no_schema") + dialect = oracle.OracleDialect() + assert dialect.identifier_preparer.format_sequence(seq) == "my_seq_no_schema" + + seq = Sequence("my_seq", schema="some_schema") + assert dialect.identifier_preparer.format_sequence(seq) == "some_schema.my_seq" + + seq = Sequence("My_Seq", schema="Some_Schema") + assert dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"' + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py index f80ddcadd6..90cc0a4774 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/postgres.py @@ -1,86 +1,457 @@ -import testbase +import testenv; testenv.configure_for_tests() import datetime from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy import exceptions from sqlalchemy.databases import postgres +from sqlalchemy.engine.strategies import MockEngineStrategy from testlib import * +from sqlalchemy.sql import table, column -class DomainReflectionTest(AssertMixin): +class SequenceTest(TestBase, AssertsCompiledSQL): + def test_basic(self): + seq = Sequence("my_seq_no_schema") + dialect = postgres.PGDialect() + assert dialect.identifier_preparer.format_sequence(seq) == "my_seq_no_schema" + + seq = Sequence("my_seq", schema="some_schema") + assert dialect.identifier_preparer.format_sequence(seq) == "some_schema.my_seq" + + seq = Sequence("My_Seq", schema="Some_Schema") + assert dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"' + +class CompileTest(TestBase, AssertsCompiledSQL): + def test_update_returning(self): + dialect = postgres.dialect() + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + 
column('description', String(128)), + ) + + u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) + + u = update(table1, values=dict(name='foo'), postgres_returning=[table1]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\ + "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) + + u = update(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect) + + def test_insert_returning(self): + dialect = postgres.dialect() + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect) + + i = insert(table1, values=dict(name='foo'), postgres_returning=[table1]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\ + "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) + + i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) + +class ReturningTest(TestBase, AssertsExecutionResults): + __only_on__ = 'postgres' + + @testing.exclude('postgres', '<', (8, 2)) + def test_update_returning(self): + meta = MetaData(testing.db) + table = Table('tables', meta, + Column('id', Integer, primary_key=True), + Column('persons', Integer), + Column('full', Boolean) + ) + table.create() + try: + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute() + self.assertEqual(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + self.assertEqual(result2.fetchall(), [(1,True),(2,False)]) + finally: + table.drop() + + @testing.exclude('postgres', '<', (8, 2)) + def test_insert_returning(self): + meta = MetaData(testing.db) + table = Table('tables', meta, + Column('id', Integer, primary_key=True), + Column('persons', Integer), + Column('full', Boolean) + ) + table.create() + try: + result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False}) + + self.assertEqual(result.fetchall(), [(1,)]) + + # Multiple inserts only return the last row + result2 = table.insert(postgres_returning=[table]).execute( + [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) + + self.assertEqual(result2.fetchall(), [(3,3,True)]) + + result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False}) + self.assertEqual([dict(row) for row in result3], [{'double_id':8}]) + + result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons') + self.assertEqual([dict(row) for row in result4], [{'persons': 10}]) + finally: + table.drop() + + +class InsertTest(TestBase, AssertsExecutionResults): + __only_on__ = 'postgres' + + def setUpAll(self): + global metadata + metadata = MetaData(testing.db) + + def tearDown(self): + 
metadata.drop_all() + metadata.tables.clear() + + def test_compiled_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + + metadata.create_all() + + ins = table.insert(values={'data':bindparam('x')}).compile() + ins.execute({'x':"five"}, {'x':"seven"}) + assert table.select().execute().fetchall() == [(1, 'five'), (2, 'seven')] + + def test_sequence_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, Sequence('my_seq'), primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_with_sequence(table, "my_seq") + + def test_opt_sequence_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_autoincrement(table) + + def test_autoincrement_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_autoincrement(table) + + def test_noautoincrement_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('data', String(30))) + metadata.create_all() + self._assert_data_noautoincrement(table) + + def _assert_data_autoincrement(self, table): + def go(): + # execute with explicit id + r = table.insert().execute({'id':30, 'data':'d1'}) + assert r.last_inserted_ids() == [30] + + # execute with prefetch id + r = table.insert().execute({'data':'d2'}) + assert r.last_inserted_ids() == [1] + + # executemany with explicit ids + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + + # executemany, uses SERIAL + table.insert().execute({'data':'d5'}, {'data':'d6'}) + + # single execute, explicit id, inline + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + + # single execute, inline, uses SERIAL + table.insert(inline=True).execute({'data':'d8'}) + + # note that the test framework doesnt capture the "preexecute" of a seqeuence + # or default. we just see it in the bind params. 
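            # (by contrast, _assert_data_with_sequence() below expects the sequence
            # to render inline as nextval('...') for the executemany and inline=True
            # cases, where no pre-execution takes place)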
+ + self.assert_sql(testing.db, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':1, 'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (1, 'd2'), + (31, 'd3'), + (32, 'd4'), + (2, 'd5'), + (3, 'd6'), + (33, 'd7'), + (4, 'd8'), + ] + table.delete().execute() + + # test the same series of events using a reflected + # version of the table + m2 = MetaData(testing.db) + table = Table(table.name, m2, autoload=True) + + def go(): + table.insert().execute({'id':30, 'data':'d1'}) + r = table.insert().execute({'data':'d2'}) + assert r.last_inserted_ids() == [5] + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + table.insert().execute({'data':'d5'}, {'data':'d6'}) + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(testing.db, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':5, 'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (5, 'd2'), + (31, 'd3'), + (32, 'd4'), + (6, 'd5'), + (7, 'd6'), + (33, 'd7'), + (8, 'd8'), + ] + table.delete().execute() + + def _assert_data_with_sequence(self, table, seqname): + def go(): + table.insert().execute({'id':30, 'data':'d1'}) + table.insert().execute({'data':'d2'}) + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + table.insert().execute({'data':'d5'}, {'data':'d6'}) + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(testing.db, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':1, 'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname, + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname, + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (1, 'd2'), + (31, 'd3'), + (32, 'd4'), + (2, 'd5'), + (3, 'd6'), + (33, 'd7'), + (4, 'd8'), + ] + + # cant test reflection here since the Sequence must be + # explicitly specified + + def _assert_data_noautoincrement(self, table): + 
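        # with autoincrement=False there is no SERIAL/sequence behind 'id', so an
        # insert that omits it should fail with a not-null violation; explicit ids
        # are then checked against both the original and a reflected copy of the table.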
table.insert().execute({'id':30, 'data':'d1'}) + try: + table.insert().execute({'data':'d2'}) + assert False + except exceptions.IntegrityError, e: + assert "violates not-null constraint" in str(e) + try: + table.insert().execute({'data':'d2'}, {'data':'d3'}) + assert False + except exceptions.IntegrityError, e: + assert "violates not-null constraint" in str(e) + + table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'}) + table.insert(inline=True).execute({'id':33, 'data':'d4'}) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (31, 'd2'), + (32, 'd3'), + (33, 'd4'), + ] + table.delete().execute() + + # test the same series of events using a reflected + # version of the table + m2 = MetaData(testing.db) + table = Table(table.name, m2, autoload=True) + table.insert().execute({'id':30, 'data':'d1'}) + try: + table.insert().execute({'data':'d2'}) + assert False + except exceptions.IntegrityError, e: + assert "violates not-null constraint" in str(e) + try: + table.insert().execute({'data':'d2'}, {'data':'d3'}) + assert False + except exceptions.IntegrityError, e: + assert "violates not-null constraint" in str(e) + + table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'}) + table.insert(inline=True).execute({'id':33, 'data':'d4'}) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (31, 'd2'), + (32, 'd3'), + (33, 'd4'), + ] + +class DomainReflectionTest(TestBase, AssertsExecutionResults): "Test PostgreSQL domains" - @testing.supported('postgres') + __only_on__ = 'postgres' + def setUpAll(self): - con = testbase.db.connect() - con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42') - con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0') + con = testing.db.connect() + try: + con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42') + con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0') + except exceptions.SQLError, e: + if not "already exists" in str(e): + raise e con.execute('CREATE TABLE testtable (question integer, answer testdomain)') con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)') con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)') - @testing.supported('postgres') def tearDownAll(self): - con = testbase.db.connect() + con = testing.db.connect() con.execute('DROP TABLE testtable') con.execute('DROP TABLE alt_schema.testtable') con.execute('DROP TABLE crosschema') con.execute('DROP DOMAIN testdomain') con.execute('DROP DOMAIN alt_schema.testdomain') - @testing.supported('postgres') def test_table_is_reflected(self): - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns") self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger) - - @testing.supported('postgres') + def test_domain_is_reflected(self): - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) self.assertEquals(str(table.columns.answer.default.arg), '42', "Reflected default value didn't equal expected value") self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.") - @testing.supported('postgres') def test_table_is_reflected_alt_schema(self): - metadata = MetaData(testbase.db) + metadata = 
MetaData(testing.db) table = Table('testtable', metadata, autoload=True, schema='alt_schema') self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns") self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger) - @testing.supported('postgres') def test_schema_domain_is_reflected(self): - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True, schema='alt_schema') self.assertEquals(str(table.columns.answer.default.arg), '0', "Reflected default value didn't equal expected value") self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.") - @testing.supported('postgres') def test_crosschema_domain_is_reflected(self): - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) table = Table('crosschema', metadata, autoload=True) self.assertEquals(str(table.columns.answer.default.arg), '0', "Reflected default value didn't equal expected value") self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.") -class MiscTest(AssertMixin): - @testing.supported('postgres') +class MiscTest(TestBase, AssertsExecutionResults): + __only_on__ = 'postgres' + def test_date_reflection(self): - m1 = MetaData(testbase.db) - t1 = Table('pgdate', m1, + m1 = MetaData(testing.db) + t1 = Table('pgdate', m1, Column('date1', DateTime(timezone=True)), Column('date2', DateTime(timezone=False)) ) m1.create_all() try: - m2 = MetaData(testbase.db) + m2 = MetaData(testing.db) t2 = Table('pgdate', m2, autoload=True) assert t2.c.date1.type.timezone is True assert t2.c.date2.type.timezone is False finally: m1.drop_all() - @testing.supported('postgres') def test_pg_weirdchar_reflection(self): - meta1 = MetaData(testbase.db) + meta1 = MetaData(testing.db) subject = Table("subject", meta1, Column("id$", Integer, primary_key=True), ) @@ -91,30 +462,45 @@ class MiscTest(AssertMixin): ) meta1.create_all() try: - meta2 = MetaData(testbase.db) + meta2 = MetaData(testing.db) subject = Table("subject", meta2, autoload=True) referer = Table("referer", meta2, autoload=True) print str(subject.join(referer).onclause) self.assert_((subject.c['id$']==referer.c.ref).compare(subject.join(referer).onclause)) finally: meta1.drop_all() - - @testing.supported('postgres') + def test_checksfor_sequence(self): - meta1 = MetaData(testbase.db) - t = Table('mytable', meta1, + meta1 = MetaData(testing.db) + t = Table('mytable', meta1, Column('col1', Integer, Sequence('fooseq'))) try: - testbase.db.execute("CREATE SEQUENCE fooseq") + testing.db.execute("CREATE SEQUENCE fooseq") t.create(checkfirst=True) finally: t.drop(checkfirst=True) - @testing.supported('postgres') + def test_distinct_on(self): + t = Table('mytable', MetaData(testing.db), + Column('id', Integer, primary_key=True), + Column('a', String(8))) + self.assertEquals( + str(t.select(distinct=t.c.a)), + 'SELECT DISTINCT ON (mytable.a) mytable.id, mytable.a \n' + 'FROM mytable') + self.assertEquals( + str(t.select(distinct=['id','a'])), + 'SELECT DISTINCT ON (id, a) mytable.id, mytable.a \n' + 'FROM mytable') + self.assertEquals( + str(t.select(distinct=[t.c.id, t.c.a])), + 'SELECT DISTINCT ON (mytable.id, mytable.a) mytable.id, mytable.a \n' + 'FROM mytable') + def test_schema_reflection(self): """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user""" - meta1 = MetaData(testbase.db) + meta1 = MetaData(testing.db) users 
= Table('users', meta1, Column('user_id', Integer, primary_key = True), Column('user_name', String(30), nullable = False), @@ -129,7 +515,7 @@ class MiscTest(AssertMixin): ) meta1.create_all() try: - meta2 = MetaData(testbase.db) + meta2 = MetaData(testing.db) addresses = Table('email_addresses', meta2, autoload=True, schema="alt_schema") users = Table('users', meta2, mustexist=True, schema="alt_schema") @@ -141,9 +527,8 @@ class MiscTest(AssertMixin): finally: meta1.drop_all() - @testing.supported('postgres') def test_schema_reflection_2(self): - meta1 = MetaData(testbase.db) + meta1 = MetaData(testing.db) subject = Table("subject", meta1, Column("id", Integer, primary_key=True), ) @@ -154,17 +539,16 @@ class MiscTest(AssertMixin): schema="alt_schema") meta1.create_all() try: - meta2 = MetaData(testbase.db) + meta2 = MetaData(testing.db) subject = Table("subject", meta2, autoload=True) referer = Table("referer", meta2, schema="alt_schema", autoload=True) print str(subject.join(referer).onclause) self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause)) finally: meta1.drop_all() - - @testing.supported('postgres') + def test_schema_reflection_3(self): - meta1 = MetaData(testbase.db) + meta1 = MetaData(testing.db) subject = Table("subject", meta1, Column("id", Integer, primary_key=True), schema='alt_schema_2' @@ -177,23 +561,46 @@ class MiscTest(AssertMixin): meta1.create_all() try: - meta2 = MetaData(testbase.db) + meta2 = MetaData(testing.db) subject = Table("subject", meta2, autoload=True, schema="alt_schema_2") referer = Table("referer", meta2, schema="alt_schema", autoload=True) print str(subject.join(referer).onclause) self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause)) finally: meta1.drop_all() - - @testing.supported('postgres') + + def test_schema_roundtrips(self): + meta = MetaData(testing.db) + users = Table('users', meta, + Column('id', Integer, primary_key=True), + Column('name', String(50)), schema='alt_schema') + users.create() + try: + users.insert().execute(id=1, name='name1') + users.insert().execute(id=2, name='name2') + users.insert().execute(id=3, name='name3') + users.insert().execute(id=4, name='name4') + + self.assertEquals(users.select().where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')]) + self.assertEquals(users.select(use_labels=True).where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')]) + + users.delete().where(users.c.id==3).execute() + self.assertEquals(users.select().where(users.c.name=='name3').execute().fetchall(), []) + + users.update().where(users.c.name=='name4').execute(name='newname') + self.assertEquals(users.select(use_labels=True).where(users.c.id==4).execute().fetchall(), [(4, 'newname')]) + + finally: + users.drop() + def test_preexecute_passivedefault(self): - """test that when we get a primary key column back + """test that when we get a primary key column back from reflecting a table which has a default value on it, we pre-execute that PassiveDefault upon insert.""" - + try: - meta = MetaData(testbase.db) - testbase.db.execute(""" + meta = MetaData(testing.db) + testing.db.execute(""" CREATE TABLE speedy_users ( speedy_user_id SERIAL PRIMARY KEY, @@ -209,17 +616,34 @@ class MiscTest(AssertMixin): l = t.select().execute().fetchall() assert l == [(1, 'user', 'lala')] finally: - testbase.db.execute("drop table speedy_users", None) - -class TimezoneTest(AssertMixin): - """test timezone-aware datetimes. 
psycopg will return a datetime with a tzinfo attached to it, - if postgres returns it. python then will not let you compare a datetime with a tzinfo to a datetime - that doesnt have one. this test illustrates two ways to have datetime types with and without timezone - info. """ - @testing.supported('postgres') + testing.db.execute("drop table speedy_users", None) + + def test_create_partial_index(self): + tbl = Table('testtbl', MetaData(), Column('data',Integer)) + idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + + executed_sql = [] + mock_strategy = MockEngineStrategy() + mock_conn = mock_strategy.create('postgres://', executed_sql.append) + + idx.create(mock_conn) + + assert executed_sql == ['CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10'] + +class TimezoneTest(TestBase, AssertsExecutionResults): + """Test timezone-aware datetimes. + + psycopg will return a datetime with a tzinfo attached to it, if postgres + returns it. python then will not let you compare a datetime with a tzinfo + to a datetime that doesnt have one. this test illustrates two ways to + have datetime types with and without timezone info. + """ + + __only_on__ = 'postgres' + def setUpAll(self): global tztable, notztable, metadata - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) # current_timestamp() in postgres is assumed to return TIMESTAMP WITH TIMEZONE tztable = Table('tztable', metadata, @@ -233,54 +657,47 @@ class TimezoneTest(AssertMixin): Column("name", String(20)), ) metadata.create_all() - @testing.supported('postgres') def tearDownAll(self): metadata.drop_all() - @testing.supported('postgres') def test_with_timezone(self): # get a date with a tzinfo - somedate = testbase.db.connect().scalar(func.current_timestamp().select()) + somedate = testing.db.connect().scalar(func.current_timestamp().select()) tztable.insert().execute(id=1, name='row1', date=somedate) c = tztable.update(tztable.c.id==1).execute(name='newname') - x = c.last_updated_params() - print x['date'] == somedate + print tztable.select(tztable.c.id==1).execute().fetchone() - @testing.supported('postgres') def test_without_timezone(self): # get a date without a tzinfo somedate = datetime.datetime(2005, 10,20, 11, 52, 00) notztable.insert().execute(id=1, name='row1', date=somedate) c = notztable.update(notztable.c.id==1).execute(name='newname') - x = c.last_updated_params() - print x['date'] == somedate + print notztable.select(tztable.c.id==1).execute().fetchone() + +class ArrayTest(TestBase, AssertsExecutionResults): + __only_on__ = 'postgres' -class ArrayTest(AssertMixin): - @testing.supported('postgres') def setUpAll(self): global metadata, arrtable - metadata = MetaData(testbase.db) - + metadata = MetaData(testing.db) + arrtable = Table('arrtable', metadata, Column('id', Integer, primary_key=True), Column('intarr', postgres.PGArray(Integer)), - Column('strarr', postgres.PGArray(String), nullable=False) + Column('strarr', postgres.PGArray(String(convert_unicode=True)), nullable=False) ) metadata.create_all() - @testing.supported('postgres') def tearDownAll(self): metadata.drop_all() - - @testing.supported('postgres') + def test_reflect_array_column(self): - metadata2 = MetaData(testbase.db) + metadata2 = MetaData(testing.db) tbl = Table('arrtable', metadata2, autoload=True) self.assertTrue(isinstance(tbl.c.intarr.type, postgres.PGArray)) self.assertTrue(isinstance(tbl.c.strarr.type, postgres.PGArray)) 
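        # reflection should recover not only PGArray itself but also its contained
        # item_type (Integer / String), verified next.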
self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer)) self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String)) - - @testing.supported('postgres') + def test_insert_array(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = arrtable.select().execute().fetchall() @@ -289,7 +706,6 @@ class ArrayTest(AssertMixin): self.assertEquals(results[0]['strarr'], ['abc','def']) arrtable.delete().execute() - @testing.supported('postgres') def test_array_where(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) arrtable.insert().execute(intarr=[4,5,6], strarr='ABC') @@ -297,8 +713,7 @@ class ArrayTest(AssertMixin): self.assertEquals(len(results), 1) self.assertEquals(results[0]['intarr'], [1,2,3]) arrtable.delete().execute() - - @testing.supported('postgres') + def test_array_concat(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall() @@ -306,5 +721,96 @@ class ArrayTest(AssertMixin): self.assertEquals(results[0][0], [1,2,3,4,5,6]) arrtable.delete().execute() + def test_array_subtype_resultprocessor(self): + arrtable.insert().execute(intarr=[4,5,6], strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6']]) + arrtable.insert().execute(intarr=[1,2,3], strarr=[u'm\xe4\xe4', u'm\xf6\xf6']) + results = arrtable.select(order_by=[arrtable.c.intarr]).execute().fetchall() + self.assertEquals(len(results), 2) + self.assertEquals(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6']) + self.assertEquals(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']]) + arrtable.delete().execute() + + def test_array_mutability(self): + class Foo(object): pass + footable = Table('foo', metadata, + Column('id', Integer, primary_key=True), + Column('intarr', postgres.PGArray(Integer), nullable=True) + ) + mapper(Foo, footable) + metadata.create_all() + sess = create_session() + + foo = Foo() + foo.id = 1 + foo.intarr = [1,2,3] + sess.save(foo) + sess.flush() + sess.clear() + foo = sess.query(Foo).get(1) + self.assertEquals(foo.intarr, [1,2,3]) + + foo.intarr.append(4) + sess.flush() + sess.clear() + foo = sess.query(Foo).get(1) + self.assertEquals(foo.intarr, [1,2,3,4]) + + foo.intarr = [] + sess.flush() + sess.clear() + self.assertEquals(foo.intarr, []) + + foo.intarr = None + sess.flush() + sess.clear() + self.assertEquals(foo.intarr, None) + + # Errors in r4217: + foo = Foo() + foo.id = 2 + sess.save(foo) + sess.flush() + +class TimeStampTest(TestBase, AssertsExecutionResults): + __only_on__ = 'postgres' + def test_timestamp(self): + engine = testing.db + connection = engine.connect() + s = select([func.TIMESTAMP("12/25/07").label("ts")]) + result = connection.execute(s).fetchone() + self.assertEqual(result[0], datetime.datetime(2007, 12, 25, 0, 0)) + +class ServerSideCursorsTest(TestBase, AssertsExecutionResults): + __only_on__ = 'postgres' + + def setUpAll(self): + global ss_engine + ss_engine = engines.testing_engine(options={'server_side_cursors':True}) + + def tearDownAll(self): + ss_engine.dispose() + + def test_roundtrip(self): + test_table = Table('test_table', MetaData(ss_engine), + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + test_table.create(checkfirst=True) + try: + test_table.insert().execute(data='data1') + + nextid = ss_engine.execute(Sequence('test_table_id_seq')) + test_table.insert().execute(id=nextid, data='data2') + + self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')]) + + 
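            # updates, deletes and count() should behave the same through the
            # server_side_cursors engine as through a regular one.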
test_table.update().where(test_table.c.id==2).values(data=test_table.c.data + ' updated').execute() + self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')]) + test_table.delete().execute() + self.assertEquals(test_table.count().scalar(), 0) + finally: + test_table.drop(checkfirst=True) + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/dialect/sqlite.py b/test/dialect/sqlite.py new file mode 100644 index 0000000000..585a853d2a --- /dev/null +++ b/test/dialect/sqlite.py @@ -0,0 +1,303 @@ +"""SQLite-specific tests.""" + +import testenv; testenv.configure_for_tests() +import datetime +from sqlalchemy import * +from sqlalchemy import exceptions +from sqlalchemy.databases import sqlite +from testlib import * + + +class TestTypes(TestBase, AssertsExecutionResults): + __only_on__ = 'sqlite' + + def test_date(self): + meta = MetaData(testing.db) + t = Table('testdate', meta, + Column('id', Integer, primary_key=True), + Column('adate', Date), + Column('adatetime', DateTime)) + meta.create_all() + try: + d1 = datetime.date(2007, 10, 30) + d2 = datetime.datetime(2007, 10, 30) + + t.insert().execute(adate=str(d1), adatetime=str(d2)) + + self.assert_(t.select().execute().fetchall()[0] == + (1, datetime.date(2007, 10, 30), + datetime.datetime(2007, 10, 30))) + + finally: + meta.drop_all() + + @testing.uses_deprecated('Using String type with no length') + def test_type_reflection(self): + # (ask_for, roundtripped_as_if_different) + specs = [( String(), sqlite.SLText(), ), + ( String(1), sqlite.SLString(1), ), + ( String(3), sqlite.SLString(3), ), + ( Text(), sqlite.SLText(), ), + ( Unicode(), sqlite.SLText(), ), + ( Unicode(1), sqlite.SLString(1), ), + ( Unicode(3), sqlite.SLString(3), ), + ( UnicodeText(), sqlite.SLText(), ), + ( CLOB, sqlite.SLText(), ), + ( sqlite.SLChar(1), ), + ( CHAR(3), sqlite.SLChar(3), ), + ( NCHAR(2), sqlite.SLChar(2), ), + ( SmallInteger(), sqlite.SLSmallInteger(), ), + ( sqlite.SLSmallInteger(), ), + ( Binary(3), sqlite.SLBinary(), ), + ( Binary(), sqlite.SLBinary() ), + ( sqlite.SLBinary(3), sqlite.SLBinary(), ), + ( NUMERIC, sqlite.SLNumeric(), ), + ( NUMERIC(10,2), sqlite.SLNumeric(10,2), ), + ( Numeric, sqlite.SLNumeric(), ), + ( Numeric(10, 2), sqlite.SLNumeric(10, 2), ), + ( DECIMAL, sqlite.SLNumeric(), ), + ( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ), + ( Float, sqlite.SLNumeric(), ), + ( sqlite.SLNumeric(), ), + ( INT, sqlite.SLInteger(), ), + ( Integer, sqlite.SLInteger(), ), + ( sqlite.SLInteger(), ), + ( TIMESTAMP, sqlite.SLDateTime(), ), + ( DATETIME, sqlite.SLDateTime(), ), + ( DateTime, sqlite.SLDateTime(), ), + ( sqlite.SLDateTime(), ), + ( DATE, sqlite.SLDate(), ), + ( Date, sqlite.SLDate(), ), + ( sqlite.SLDate(), ), + ( TIME, sqlite.SLTime(), ), + ( Time, sqlite.SLTime(), ), + ( sqlite.SLTime(), ), + ( BOOLEAN, sqlite.SLBoolean(), ), + ( Boolean, sqlite.SLBoolean(), ), + ( sqlite.SLBoolean(), ), + ] + columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)] + + db = testing.db + m = MetaData(db) + t_table = Table('types', m, *columns) + try: + m.create_all() + + m2 = MetaData(db) + rt = Table('types', m2, autoload=True) + try: + db.execute('CREATE VIEW types_v AS SELECT * from types') + rv = Table('types_v', m2, autoload=True) + + expected = [len(c) > 1 and c[1] or c[0] for c in specs] + for table in rt, rv: + for i, reflected in enumerate(table.c): + print reflected.type, type(expected[i]) + assert isinstance(reflected.type, type(expected[i])) + finally: + 
db.execute('DROP VIEW types_v') + finally: + m.drop_all() + +class DialectTest(TestBase, AssertsExecutionResults): + __only_on__ = 'sqlite' + + def test_extra_reserved_words(self): + """Tests reserved words in identifiers. + + 'true', 'false', and 'column' are undocumented reserved words + when used as column identifiers (as of 3.5.1). Covering them here + to ensure they remain in place if the dialect's reserved_words set + is updated in the future. + """ + + meta = MetaData(testing.db) + t = Table('reserved', meta, + Column('safe', Integer), + Column('true', Integer), + Column('false', Integer), + Column('column', Integer)) + + try: + meta.create_all() + t.insert().execute(safe=1) + list(t.select().execute()) + finally: + meta.drop_all() + + def test_quoted_identifiers(self): + """Tests autoload of tables created with quoted column names.""" + + # This is quirky in sqlite. + testing.db.execute("""CREATE TABLE "django_content_type" ( + "id" integer NOT NULL PRIMARY KEY, + "django_stuff" text NULL + ) + """) + testing.db.execute(""" + CREATE TABLE "django_admin_log" ( + "id" integer NOT NULL PRIMARY KEY, + "action_time" datetime NOT NULL, + "content_type_id" integer NULL REFERENCES "django_content_type" ("id"), + "object_id" text NULL, + "change_message" text NOT NULL + ) + """) + try: + meta = MetaData(testing.db) + table1 = Table("django_admin_log", meta, autoload=True) + table2 = Table("django_content_type", meta, autoload=True) + j = table1.join(table2) + assert j.onclause == table1.c.content_type_id==table2.c.id + finally: + testing.db.execute("drop table django_admin_log") + testing.db.execute("drop table django_content_type") + + + def test_attached_as_schema(self): + cx = testing.db.connect() + try: + cx.execute('ATTACH DATABASE ":memory:" AS alt_schema') + dialect = cx.dialect + assert dialect.table_names(cx, 'alt_schema') == [] + + meta = MetaData(cx) + Table('created', meta, Column('id', Integer), + schema='alt_schema') + alt_master = Table('sqlite_master', meta, autoload=True, + schema='alt_schema') + meta.create_all(cx) + + self.assertEquals(dialect.table_names(cx, 'alt_schema'), + ['created']) + assert len(alt_master.c) > 0 + + meta.clear() + reflected = Table('created', meta, autoload=True, + schema='alt_schema') + assert len(reflected.c) == 1 + + cx.execute(reflected.insert(), dict(id=1)) + r = cx.execute(reflected.select()).fetchall() + assert list(r) == [(1,)] + + cx.execute(reflected.update(), dict(id=2)) + r = cx.execute(reflected.select()).fetchall() + assert list(r) == [(2,)] + + cx.execute(reflected.delete(reflected.c.id==2)) + r = cx.execute(reflected.select()).fetchall() + assert list(r) == [] + + # note that sqlite_master is cleared, above + meta.drop_all() + + assert dialect.table_names(cx, 'alt_schema') == [] + finally: + cx.execute('DETACH DATABASE alt_schema') + + @testing.exclude('sqlite', '<', (2, 6)) + def test_temp_table_reflection(self): + cx = testing.db.connect() + try: + cx.execute('CREATE TEMPORARY TABLE tempy (id INT)') + + assert 'tempy' in cx.dialect.table_names(cx, None) + + meta = MetaData(cx) + tempy = Table('tempy', meta, autoload=True) + assert len(tempy.c) == 1 + meta.drop_all() + except: + try: + cx.execute('DROP TABLE tempy') + except exceptions.DBAPIError: + pass + raise + +class InsertTest(TestBase, AssertsExecutionResults): + """Tests inserts and autoincrement.""" + + __only_on__ = 'sqlite' + + # empty insert (i.e. INSERT INTO table DEFAULT VALUES) + # fails as recently as sqlite 3.3.6. passes on 3.4.1. 
this syntax + # is nowhere to be found in the sqlite3 documentation or changelog, so can't + # determine what versions in between it's legal for. + def _test_empty_insert(self, table, expect=1): + try: + table.create() + for wanted in (expect, expect * 2): + + table.insert().execute() + + rows = table.select().execute().fetchall() + print rows + self.assertEquals(len(rows), wanted) + finally: + table.drop() + + @testing.exclude('sqlite', '<', (3, 4)) + def test_empty_insert_pk1(self): + self._test_empty_insert( + Table('a', MetaData(testing.db), + Column('id', Integer, primary_key=True))) + + @testing.exclude('sqlite', '<', (3, 4)) + def test_empty_insert_pk2(self): + self.assertRaises( + exceptions.DBAPIError, + self._test_empty_insert, + Table('b', MetaData(testing.db), + Column('x', Integer, primary_key=True), + Column('y', Integer, primary_key=True))) + + @testing.exclude('sqlite', '<', (3, 4)) + def test_empty_insert_pk3(self): + self.assertRaises( + exceptions.DBAPIError, + self._test_empty_insert, + Table('c', MetaData(testing.db), + Column('x', Integer, primary_key=True), + Column('y', Integer, PassiveDefault('123'), + primary_key=True))) + + @testing.exclude('sqlite', '<', (3, 4)) + def test_empty_insert_pk4(self): + self._test_empty_insert( + Table('d', MetaData(testing.db), + Column('x', Integer, primary_key=True), + Column('y', Integer, PassiveDefault('123')))) + + @testing.exclude('sqlite', '<', (3, 4)) + def test_empty_insert_nopk1(self): + self._test_empty_insert( + Table('e', MetaData(testing.db), + Column('id', Integer))) + + @testing.exclude('sqlite', '<', (3, 4)) + def test_empty_insert_nopk2(self): + self._test_empty_insert( + Table('f', MetaData(testing.db), + Column('x', Integer), + Column('y', Integer))) + + def test_inserts_with_spaces(self): + tbl = Table('tbl', MetaData('sqlite:///'), + Column('with space', Integer), + Column('without', Integer)) + tbl.create() + try: + tbl.insert().execute({'without':123}) + assert list(tbl.select().execute()) == [(None, 123)] + + tbl.insert().execute({'with space':456}) + assert list(tbl.select().execute()) == [(None, 123), (456, None)] + + finally: + tbl.drop() + + +if __name__ == "__main__": + testenv.main() diff --git a/test/dialect/sybase.py b/test/dialect/sybase.py new file mode 100644 index 0000000000..19cca465bd --- /dev/null +++ b/test/dialect/sybase.py @@ -0,0 +1,15 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy.databases import sybase +from testlib import * + + +class BasicTest(TestBase, AssertsExecutionResults): + # A simple import of the database/ module should work on all systems. + def test_import(self): + # we got this far, right? 
+ return True + + +if __name__ == "__main__": + testenv.main() diff --git a/test/engine/alltests.py b/test/engine/alltests.py index a34a82ed75..75167d5d6f 100644 --- a/test/engine/alltests.py +++ b/test/engine/alltests.py @@ -1,20 +1,21 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest def suite(): modules_to_test = ( # connectivity, execution - 'engine.parseconnect', - 'engine.pool', + 'engine.parseconnect', + 'engine.pool', 'engine.bind', 'engine.reconnect', 'engine.execute', 'engine.metadata', 'engine.transaction', - + # schema/tables - 'engine.reflection', + 'engine.reflection', + 'engine.ddlevents', ) alltests = unittest.TestSuite() @@ -26,6 +27,5 @@ def suite(): return alltests - if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/engine/bind.py b/test/engine/bind.py index 6a0c78f578..b59cd284a1 100644 --- a/test/engine/bind.py +++ b/test/engine/bind.py @@ -1,18 +1,20 @@ """tests the "bind" attribute/argument across schema, SQL, and ORM sessions, including the deprecated versions of these arguments""" -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import engine, exceptions from testlib import * -class BindTest(PersistTest): + +class BindTest(TestBase): def test_create_drop_explicit(self): metadata = MetaData() - table = Table('test_table', metadata, + table = Table('test_table', metadata, Column('foo', Integer)) for bind in ( - testbase.db, - testbase.db.connect() + testing.db, + testing.db.connect() ): for args in [ ([], {'bind':bind}), @@ -24,15 +26,14 @@ class BindTest(PersistTest): table.create(*args[0], **args[1]) table.drop(*args[0], **args[1]) assert not table.exists(*args[0], **args[1]) - + def test_create_drop_err(self): metadata = MetaData() - table = Table('test_table', metadata, + table = Table('test_table', metadata, Column('foo', Integer)) for meth in [ metadata.create_all, - table.exists, metadata.drop_all, table.create, table.drop, @@ -40,18 +41,64 @@ class BindTest(PersistTest): try: meth() assert False - except exceptions.InvalidRequestError, e: - assert str(e) == "This SchemaItem is not connected to any Engine or Connection." - + except exceptions.UnboundExecutionError, e: + self.assertEquals( + str(e), + "The MetaData " + "is not bound to an Engine or Connection. " + "Execution can not proceed without a database to execute " + "against. Either execute with an explicit connection or " + "assign the MetaData's .bind to enable implicit execution.") + + for meth in [ + table.exists, + # future: + #table.create, + #table.drop, + ]: + try: + meth() + assert False + except exceptions.UnboundExecutionError, e: + self.assertEquals( + str(e), + "The Table 'test_table' " + "is not bound to an Engine or Connection. " + "Execution can not proceed without a database to execute " + "against. Either execute with an explicit connection or " + "assign this Table's .metadata.bind to enable implicit " + "execution.") + + @testing.future + def test_create_drop_err2(self): + for meth in [ + table.exists, + table.create, + table.drop, + ]: + try: + meth() + assert False + except exceptions.UnboundExecutionError, e: + self.assertEquals( + str(e), + "The Table 'test_table' " + "is not bound to an Engine or Connection. " + "Execution can not proceed without a database to execute " + "against. 
Either execute with an explicit connection or " + "assign this Table's .metadata.bind to enable implicit " + "execution.") + + @testing.uses_deprecated('//connect') def test_create_drop_bound(self): - + for meta in (MetaData,ThreadLocalMetaData): for bind in ( - testbase.db, - testbase.db.connect() + testing.db, + testing.db.connect() ): metadata = meta() - table = Table('test_table', metadata, + table = Table('test_table', metadata, Column('foo', Integer)) metadata.bind = bind assert metadata.bind is table.bind is bind @@ -63,10 +110,11 @@ class BindTest(PersistTest): assert not table.exists() metadata = meta() - table = Table('test_table', metadata, + table = Table('test_table', metadata, Column('foo', Integer)) metadata.connect(bind) + assert metadata.bind is table.bind is bind metadata.create_all() assert table.exists() @@ -79,8 +127,8 @@ class BindTest(PersistTest): def test_create_drop_constructor_bound(self): for bind in ( - testbase.db, - testbase.db.connect() + testing.db, + testing.db.connect() ): try: for args in ( @@ -88,7 +136,7 @@ class BindTest(PersistTest): ([], {'bind':bind}), ): metadata = MetaData(*args[0], **args[1]) - table = Table('test_table', metadata, + table = Table('test_table', metadata, Column('foo', Integer)) assert metadata.bind is table.bind is bind metadata.create_all() @@ -103,11 +151,11 @@ class BindTest(PersistTest): def test_implicit_execution(self): metadata = MetaData() - table = Table('test_table', metadata, + table = Table('test_table', metadata, Column('foo', Integer), test_needs_acid=True, ) - conn = testbase.db.connect() + conn = testing.db.connect() metadata.create_all(bind=conn) try: trans = conn.begin() @@ -122,13 +170,13 @@ class BindTest(PersistTest): assert conn.execute("select count(1) from test_table").scalar() == 0 finally: metadata.drop_all(bind=conn) - + def test_clauseelement(self): metadata = MetaData() - table = Table('test_table', metadata, + table = Table('test_table', metadata, Column('foo', Integer)) - metadata.create_all(bind=testbase.db) + metadata.create_all(bind=testing.db) try: for elem in [ table.select, @@ -137,8 +185,8 @@ class BindTest(PersistTest): lambda **kwargs:text("select * from test_table", **kwargs) ]: for bind in ( - testbase.db, - testbase.db.connect() + testing.db, + testing.db.connect() ): try: e = elem(bind=bind) @@ -153,27 +201,31 @@ class BindTest(PersistTest): assert e.bind is None e.execute() assert False - except exceptions.InvalidRequestError, e: - assert str(e) == "This Compiled object is not bound to any Engine or Connection." - + except exceptions.UnboundExecutionError, e: + assert str(e).endswith( + 'is not bound and does not support direct ' + 'execution. Supply this statement to a Connection or ' + 'Engine for execution. 
Or, assign a bind to the ' + 'statement or the Metadata of its underlying tables to ' + 'enable implicit execution via this method.') finally: if isinstance(bind, engine.Connection): bind.close() - metadata.drop_all(bind=testbase.db) - + metadata.drop_all(bind=testing.db) + def test_session(self): from sqlalchemy.orm import create_session, mapper metadata = MetaData() - table = Table('test_table', metadata, - Column('foo', Integer, primary_key=True), + table = Table('test_table', metadata, + Column('foo', Integer, Sequence('foo_seq', optional=True), primary_key=True), Column('data', String(30))) class Foo(object): pass mapper(Foo, table) - metadata.create_all(bind=testbase.db) + metadata.create_all(bind=testing.db) try: - for bind in (testbase.db, - testbase.db.connect() + for bind in (testing.db, + testing.db.connect() ): try: for args in ({'bind':bind},): @@ -189,7 +241,7 @@ class BindTest(PersistTest): if isinstance(bind, engine.Connection): bind.close() - + sess = create_session() f = Foo() sess.save(f) @@ -201,8 +253,8 @@ class BindTest(PersistTest): finally: if isinstance(bind, engine.Connection): bind.close() - metadata.drop_all(bind=testbase.db) - - + metadata.drop_all(bind=testing.db) + + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/engine/ddlevents.py b/test/engine/ddlevents.py new file mode 100644 index 0000000000..258c614120 --- /dev/null +++ b/test/engine/ddlevents.py @@ -0,0 +1,362 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy import exceptions +from sqlalchemy.schema import DDL +import sqlalchemy +from testlib import * + + +class DDLEventTest(TestBase): + class Canary(object): + def __init__(self, schema_item, bind): + self.state = None + self.schema_item = schema_item + self.bind = bind + + def before_create(self, action, schema_item, bind): + assert self.state is None + assert schema_item is self.schema_item + assert bind is self.bind + self.state = action + + def after_create(self, action, schema_item, bind): + assert self.state in ('before-create', 'skipped') + assert schema_item is self.schema_item + assert bind is self.bind + self.state = action + + def before_drop(self, action, schema_item, bind): + assert self.state is None + assert schema_item is self.schema_item + assert bind is self.bind + self.state = action + + def after_drop(self, action, schema_item, bind): + assert self.state in ('before-drop', 'skipped') + assert schema_item is self.schema_item + assert bind is self.bind + self.state = action + + def mock_engine(self): + buffer = [] + def executor(sql, *a, **kw): + buffer.append(sql) + engine = create_engine(testing.db.name + '://', + strategy='mock', executor=executor) + assert not hasattr(engine, 'mock') + engine.mock = buffer + return engine + + def setUp(self): + self.bind = self.mock_engine() + self.metadata = MetaData() + self.table = Table('t', self.metadata, Column('id', Integer)) + + def test_table_create_before(self): + table, bind = self.table, self.bind + canary = self.Canary(table, bind) + table.ddl_listeners['before-create'].append(canary.before_create) + + table.create(bind) + assert canary.state == 'before-create' + table.drop(bind) + assert canary.state == 'before-create' + + def test_table_create_after(self): + table, bind = self.table, self.bind + canary = self.Canary(table, bind) + table.ddl_listeners['after-create'].append(canary.after_create) + + canary.state = 'skipped' + table.create(bind) + assert canary.state == 'after-create' + table.drop(bind) + 
assert canary.state == 'after-create' + + def test_table_create_both(self): + table, bind = self.table, self.bind + canary = self.Canary(table, bind) + table.ddl_listeners['before-create'].append(canary.before_create) + table.ddl_listeners['after-create'].append(canary.after_create) + + table.create(bind) + assert canary.state == 'after-create' + table.drop(bind) + assert canary.state == 'after-create' + + def test_table_drop_before(self): + table, bind = self.table, self.bind + canary = self.Canary(table, bind) + table.ddl_listeners['before-drop'].append(canary.before_drop) + + table.create(bind) + assert canary.state is None + table.drop(bind) + assert canary.state == 'before-drop' + + def test_table_drop_after(self): + table, bind = self.table, self.bind + canary = self.Canary(table, bind) + table.ddl_listeners['after-drop'].append(canary.after_drop) + + table.create(bind) + assert canary.state is None + canary.state = 'skipped' + table.drop(bind) + assert canary.state == 'after-drop' + + def test_table_drop_both(self): + table, bind = self.table, self.bind + canary = self.Canary(table, bind) + table.ddl_listeners['before-drop'].append(canary.before_drop) + table.ddl_listeners['after-drop'].append(canary.after_drop) + + table.create(bind) + assert canary.state is None + table.drop(bind) + assert canary.state == 'after-drop' + + def test_table_all(self): + table, bind = self.table, self.bind + canary = self.Canary(table, bind) + table.ddl_listeners['before-create'].append(canary.before_create) + table.ddl_listeners['after-create'].append(canary.after_create) + table.ddl_listeners['before-drop'].append(canary.before_drop) + table.ddl_listeners['after-drop'].append(canary.after_drop) + + assert canary.state is None + table.create(bind) + assert canary.state == 'after-create' + canary.state = None + table.drop(bind) + assert canary.state == 'after-drop' + + def test_table_create_before(self): + metadata, bind = self.metadata, self.bind + canary = self.Canary(metadata, bind) + metadata.ddl_listeners['before-create'].append(canary.before_create) + + metadata.create_all(bind) + assert canary.state == 'before-create' + metadata.drop_all(bind) + assert canary.state == 'before-create' + + def test_metadata_create_after(self): + metadata, bind = self.metadata, self.bind + canary = self.Canary(metadata, bind) + metadata.ddl_listeners['after-create'].append(canary.after_create) + + canary.state = 'skipped' + metadata.create_all(bind) + assert canary.state == 'after-create' + metadata.drop_all(bind) + assert canary.state == 'after-create' + + def test_metadata_create_both(self): + metadata, bind = self.metadata, self.bind + canary = self.Canary(metadata, bind) + metadata.ddl_listeners['before-create'].append(canary.before_create) + metadata.ddl_listeners['after-create'].append(canary.after_create) + + metadata.create_all(bind) + assert canary.state == 'after-create' + metadata.drop_all(bind) + assert canary.state == 'after-create' + + @testing.future + def test_metadata_table_isolation(self): + metadata, table, bind = self.metadata, self.table, self.bind + + table_canary = self.Canary(table, bind) + table.ddl_listeners['before-create'].append(table_canary.before_create) + + metadata_canary = self.Canary(metadata, bind) + metadata.ddl_listeners['before-create'].append(metadata_canary.before_create) + + # currently, table.create() routes through the same execution + # path that metadata.create_all() does + self.table.create(self.bind) + assert metadata_canary.state == None + + def 
test_append_listener(self): + metadata, table, bind = self.metadata, self.table, self.bind + + fn = lambda *a: None + + table.append_ddl_listener('before-create', fn) + self.assertRaises(LookupError, table.append_ddl_listener, 'blah', fn) + + metadata.append_ddl_listener('before-create', fn) + self.assertRaises(LookupError, metadata.append_ddl_listener, 'blah', fn) + + +class DDLExecutionTest(TestBase): + def mock_engine(self): + buffer = [] + def executor(sql, *a, **kw): + buffer.append(sql) + engine = create_engine(testing.db.name + '://', + strategy='mock', executor=executor) + assert not hasattr(engine, 'mock') + engine.mock = buffer + return engine + + def setUp(self): + self.engine = self.mock_engine() + self.metadata = MetaData(self.engine) + self.users = Table('users', self.metadata, + Column('user_id', Integer, primary_key=True), + Column('user_name', String(40)), + ) + + def test_table_standalone(self): + users, engine = self.users, self.engine + DDL('mxyzptlk').execute_at('before-create', users) + DDL('klptzyxm').execute_at('after-create', users) + DDL('xyzzy').execute_at('before-drop', users) + DDL('fnord').execute_at('after-drop', users) + + users.create() + strings = [str(x) for x in engine.mock] + assert 'mxyzptlk' in strings + assert 'klptzyxm' in strings + assert 'xyzzy' not in strings + assert 'fnord' not in strings + del engine.mock[:] + users.drop() + strings = [str(x) for x in engine.mock] + assert 'mxyzptlk' not in strings + assert 'klptzyxm' not in strings + assert 'xyzzy' in strings + assert 'fnord' in strings + + def test_table_by_metadata(self): + metadata, users, engine = self.metadata, self.users, self.engine + DDL('mxyzptlk').execute_at('before-create', users) + DDL('klptzyxm').execute_at('after-create', users) + DDL('xyzzy').execute_at('before-drop', users) + DDL('fnord').execute_at('after-drop', users) + + metadata.create_all() + strings = [str(x) for x in engine.mock] + assert 'mxyzptlk' in strings + assert 'klptzyxm' in strings + assert 'xyzzy' not in strings + assert 'fnord' not in strings + del engine.mock[:] + metadata.drop_all() + strings = [str(x) for x in engine.mock] + assert 'mxyzptlk' not in strings + assert 'klptzyxm' not in strings + assert 'xyzzy' in strings + assert 'fnord' in strings + + def test_metadata(self): + metadata, engine = self.metadata, self.engine + DDL('mxyzptlk').execute_at('before-create', metadata) + DDL('klptzyxm').execute_at('after-create', metadata) + DDL('xyzzy').execute_at('before-drop', metadata) + DDL('fnord').execute_at('after-drop', metadata) + + metadata.create_all() + strings = [str(x) for x in engine.mock] + assert 'mxyzptlk' in strings + assert 'klptzyxm' in strings + assert 'xyzzy' not in strings + assert 'fnord' not in strings + del engine.mock[:] + metadata.drop_all() + strings = [str(x) for x in engine.mock] + assert 'mxyzptlk' not in strings + assert 'klptzyxm' not in strings + assert 'xyzzy' in strings + assert 'fnord' in strings + + def test_ddl_execute(self): + engine = create_engine('sqlite:///') + cx = engine.connect() + table = self.users + ddl = DDL('SELECT 1') + + for py in ('engine.execute(ddl)', + 'engine.execute(ddl, table)', + 'cx.execute(ddl)', + 'cx.execute(ddl, table)', + 'ddl.execute(engine)', + 'ddl.execute(engine, table)', + 'ddl.execute(cx)', + 'ddl.execute(cx, table)'): + r = eval(py) + assert list(r) == [(1,)], py + + for py in ('ddl.execute()', + 'ddl.execute(schema_item=table)'): + try: + r = eval(py) + assert False + except exceptions.UnboundExecutionError: + pass + + for bind in 
engine, cx: + ddl.bind = bind + for py in ('ddl.execute()', + 'ddl.execute(schema_item=table)'): + r = eval(py) + assert list(r) == [(1,)], py + +class DDLTest(TestBase): + def mock_engine(self): + executor = lambda *a, **kw: None + engine = create_engine(testing.db.name + '://', + strategy='mock', executor=executor) + engine.dialect.identifier_preparer = \ + sqlalchemy.sql.compiler.IdentifierPreparer(engine.dialect) + return engine + + def test_tokens(self): + m = MetaData() + bind = self.mock_engine() + sane_alone = Table('t', m, Column('id', Integer)) + sane_schema = Table('t', m, Column('id', Integer), schema='s') + insane_alone = Table('t t', m, Column('id', Integer)) + insane_schema = Table('t t', m, Column('id', Integer), schema='s s') + + ddl = DDL('%(schema)s-%(table)s-%(fullname)s') + + self.assertEquals(ddl._expand(sane_alone, bind), '-t-t') + self.assertEquals(ddl._expand(sane_schema, bind), '"s"-t-s.t') + self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"') + self.assertEquals(ddl._expand(insane_schema, bind), + '"s s"-"t t"-"s s"."t t"') + + # overrides are used piece-meal and verbatim. + ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s', + context={'schema':'S S', 'table': 'T T', 'bonus': 'b'}) + self.assertEquals(ddl._expand(sane_alone, bind), 'S S-T T-t-b') + self.assertEquals(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b') + self.assertEquals(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b') + self.assertEquals(ddl._expand(insane_schema, bind), + 'S S-T T-"s s"."t t"-b') + def test_filter(self): + cx = self.mock_engine() + cx.name = 'mock' + + tbl = Table('t', MetaData(), Column('id', Integer)) + + assert DDL('')._should_execute('x', tbl, cx) + assert DDL('', on='mock')._should_execute('x', tbl, cx) + assert not DDL('', on='bogus')._should_execute('x', tbl, cx) + assert DDL('', on=lambda x,y,z: True)._should_execute('x', tbl, cx) + assert(DDL('', on=lambda x,y,z: z.engine.name != 'bogus'). 
+ _should_execute('x', tbl, cx)) + + def test_repr(self): + assert repr(DDL('s')) + assert repr(DDL('s', on='engine')) + assert repr(DDL('s', on=lambda x: 1)) + assert repr(DDL('s', context={'a':1})) + assert repr(DDL('s', on='engine', context={'a':1})) + + +if __name__ == "__main__": + testenv.main() diff --git a/test/engine/execute.py b/test/engine/execute.py index 3d3b43f9b6..260a05e270 100644 --- a/test/engine/execute.py +++ b/test/engine/execute.py @@ -1,66 +1,77 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions from testlib import * -class ExecuteTest(PersistTest): +class ExecuteTest(TestBase): def setUpAll(self): global users, metadata - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) users = Table('users', metadata, Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), ) metadata.create_all() - + def tearDown(self): - testbase.db.connect().execute(users.delete()) + testing.db.connect().execute(users.delete()) def tearDownAll(self): metadata.drop_all() - - @testing.supported('sqlite') + + @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite') def test_raw_qmark(self): - for conn in (testbase.db, testbase.db.connect()): + for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) conn.execute("insert into users (user_id, user_name) values (?, ?)", [2,"fred"]) conn.execute("insert into users (user_id, user_name) values (?, ?)", [3,"ed"], [4,"horse"]) conn.execute("insert into users (user_id, user_name) values (?, ?)", (5,"barney"), (6,"donkey")) conn.execute("insert into users (user_id, user_name) values (?, ?)", 7, 'sally') - res = conn.execute("select * from users") + res = conn.execute("select * from users order by user_id") assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')] conn.execute("delete from users") - @testing.supported('mysql', 'postgres') + @testing.fails_on_everything_except('mysql', 'postgres') + # some psycopg2 versions bomb this. def test_raw_sprintf(self): - for conn in (testbase.db, testbase.db.connect()): + for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (%s, %s)", [1,"jack"]) conn.execute("insert into users (user_id, user_name) values (%s, %s)", [2,"ed"], [3,"horse"]) conn.execute("insert into users (user_id, user_name) values (%s, %s)", 4, 'sally') conn.execute("insert into users (user_id) values (%s)", 5) - res = conn.execute("select * from users") + res = conn.execute("select * from users order by user_id") assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally'), (5, None)] conn.execute("delete from users") # pyformat is supported for mysql, but skipping because a few driver # versions have a bug that bombs out on this test. 
(1.2.2b3, 1.2.2c1, 1.2.2) - @testing.supported('postgres') + @testing.unsupported('mysql') + @testing.fails_on_everything_except('postgres') def test_raw_python(self): - for conn in (testbase.db, testbase.db.connect()): + for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'}) conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'}) conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", id=4, name='sally') - res = conn.execute("select * from users") + res = conn.execute("select * from users order by user_id") assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')] conn.execute("delete from users") - @testing.supported('sqlite') + @testing.fails_on_everything_except('sqlite', 'oracle') def test_raw_named(self): - for conn in (testbase.db, testbase.db.connect()): + for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'}) conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'}) conn.execute("insert into users (user_id, user_name) values (:id, :name)", id=4, name='sally') - res = conn.execute("select * from users") + res = conn.execute("select * from users order by user_id") assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')] conn.execute("delete from users") - + + def test_exception_wrapping(self): + for conn in (testing.db, testing.db.connect()): + try: + conn.execute("osdjafioajwoejoasfjdoifjowejfoawejqoijwef") + assert False + except exceptions.DBAPIError: + assert True + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/engine/metadata.py b/test/engine/metadata.py index 973007fab8..22cdaafee4 100644 --- a/test/engine/metadata.py +++ b/test/engine/metadata.py @@ -1,18 +1,117 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions from testlib import * +import pickle -class MetaDataTest(PersistTest): +class MetaDataTest(TestBase, ComparesTables): def test_metadata_connect(self): metadata = MetaData() t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True), Column('col2', String(20))) - metadata.bind = testbase.db + metadata.bind = testing.db metadata.create_all() try: assert t1.count().scalar() == 0 finally: metadata.drop_all() - + + + def test_dupe_tables(self): + metadata = MetaData() + t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True), + Column('col2', String(20))) + + metadata.bind = testing.db + metadata.create_all() + try: + try: + t1 = Table('table1', metadata, autoload=True) + t2 = Table('table1', metadata, Column('col1', Integer, primary_key=True), + Column('col2', String(20))) + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Table 'table1' is already defined for this MetaData instance. Specify 'useexisting=True' to redefine options and columns on an existing Table object." 
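The error message asserted just above comes from redefining a Table on a MetaData that already holds it; a short sketch of that behavior, hedged against the 0.4-era useexisting flag (names are illustrative):

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy import exceptions

    meta = MetaData()
    t1 = Table('table1', meta, Column('col1', Integer, primary_key=True))
    try:
        Table('table1', meta, Column('col1', Integer, primary_key=True))
    except exceptions.InvalidRequestError:
        pass  # already defined for this MetaData instance
    t2 = Table('table1', meta, Column('col1', Integer, primary_key=True),
               useexisting=True)
    assert t2 is t1
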
+ finally: + metadata.drop_all() + + @testing.exclude('mysql', '<', (4, 1, 1)) + def test_to_metadata(self): + meta = MetaData() + + table = Table('mytable', meta, + Column('myid', Integer, primary_key=True), + Column('name', String(40), nullable=True), + Column('description', String(30), CheckConstraint("description='hi'")), + UniqueConstraint('name'), + test_needs_fk=True, + ) + + table2 = Table('othertable', meta, + Column('id', Integer, primary_key=True), + Column('myid', Integer, ForeignKey('mytable.myid')), + test_needs_fk=True, + ) + + def test_to_metadata(): + meta2 = MetaData() + table_c = table.tometadata(meta2) + table2_c = table2.tometadata(meta2) + return (table_c, table2_c) + + def test_pickle(): + meta.bind = testing.db + meta2 = pickle.loads(pickle.dumps(meta)) + assert meta2.bind is None + meta3 = pickle.loads(pickle.dumps(meta2)) + return (meta2.tables['mytable'], meta2.tables['othertable']) + + def test_pickle_via_reflect(): + # this is the most common use case, pickling the results of a + # database reflection + meta2 = MetaData(bind=testing.db) + t1 = Table('mytable', meta2, autoload=True) + t2 = Table('othertable', meta2, autoload=True) + meta3 = pickle.loads(pickle.dumps(meta2)) + assert meta3.bind is None + assert meta3.tables['mytable'] is not t1 + return (meta3.tables['mytable'], meta3.tables['othertable']) + + meta.create_all(testing.db) + try: + for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)): + table_c, table2_c = test() + self.assert_tables_equal(table, table_c) + self.assert_tables_equal(table2, table2_c) + + assert table is not table_c + assert table.primary_key is not table_c.primary_key + assert list(table2_c.c.myid.foreign_keys)[0].column is table_c.c.myid + assert list(table2_c.c.myid.foreign_keys)[0].column is not table.c.myid + + # constraints dont get reflected for any dialect right now + if has_constraints: + for c in table_c.c.description.constraints: + if isinstance(c, CheckConstraint): + break + else: + assert False + assert c.sqltext=="description='hi'" + + for c in table_c.constraints: + if isinstance(c, UniqueConstraint): + break + else: + assert False + assert c.columns.contains_column(table_c.c.name) + assert not c.columns.contains_column(table.c.name) + finally: + meta.drop_all(testing.db) + + def test_nonexistent(self): + self.assertRaises(exceptions.NoSuchTableError, Table, + 'fake_table', + MetaData(testing.db), autoload=True) + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/engine/parseconnect.py b/test/engine/parseconnect.py index 3e186275d5..117c3ed4bb 100644 --- a/test/engine/parseconnect.py +++ b/test/engine/parseconnect.py @@ -1,11 +1,13 @@ -import testbase +import testenv; testenv.configure_for_tests() +import ConfigParser, StringIO from sqlalchemy import * +from sqlalchemy import exceptions, pool, engine import sqlalchemy.engine.url as url from testlib import * - -class ParseConnectTest(PersistTest): - def testrfc1738(self): + +class ParseConnectTest(TestBase): + def test_rfc1738(self): for text in ( 'dbtype://username:password@hostspec:110//usr/db_file.db', 'dbtype://username:password@hostspec/database', @@ -35,67 +37,125 @@ class ParseConnectTest(PersistTest): assert u.host == 'hostspec' or u.host == '127.0.0.1' or (not u.host) assert str(u) == text -class CreateEngineTest(PersistTest): +class CreateEngineTest(TestBase): """test that create_engine arguments of different types get propigated properly""" - def testconnectquery(self): + 
def test_connect_query(self): dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue') - + # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg e = create_engine('postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', module=dbapi) c = e.connect() - def testkwargs(self): + def test_kwargs(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg e = create_engine('postgres://scott:tiger@somehost/test?fooz=somevalue', connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, module=dbapi) c = e.connect() - def testcustom(self): + def test_coerce_config(self): + raw = r""" +[prefixed] +sqlalchemy.url=postgres://scott:tiger@somehost/test?fooz=somevalue +sqlalchemy.convert_unicode=0 +sqlalchemy.echo=false +sqlalchemy.echo_pool=1 +sqlalchemy.max_overflow=2 +sqlalchemy.pool_recycle=50 +sqlalchemy.pool_size=2 +sqlalchemy.pool_threadlocal=1 +sqlalchemy.pool_timeout=10 +[plain] +url=postgres://scott:tiger@somehost/test?fooz=somevalue +convert_unicode=0 +echo=0 +echo_pool=1 +max_overflow=2 +pool_recycle=50 +pool_size=2 +pool_threadlocal=1 +pool_timeout=10 +""" + ini = ConfigParser.ConfigParser() + ini.readfp(StringIO.StringIO(raw)) + + expected = { + 'url': 'postgres://scott:tiger@somehost/test?fooz=somevalue', + 'convert_unicode': 0, + 'echo': False, + 'echo_pool': True, + 'max_overflow': 2, + 'pool_recycle': 50, + 'pool_size': 2, + 'pool_threadlocal': True, + 'pool_timeout': 10, + } + + prefixed = dict(ini.items('prefixed')) + self.assert_(engine._coerce_config(prefixed, 'sqlalchemy.') == expected) + + plain = dict(ini.items('plain')) + self.assert_(engine._coerce_config(plain, '') == expected) + + def test_engine_from_config(self): + dbapi = MockDBAPI() + + config = { + 'sqlalchemy.url':'postgres://scott:tiger@somehost/test?fooz=somevalue', + 'sqlalchemy.pool_recycle':'50', + 'sqlalchemy.echo':'true' + } + + e = engine_from_config(config, module=dbapi) + assert e.pool._recycle == 50 + assert e.url == url.make_url('postgres://scott:tiger@somehost/test?fooz=somevalue') + assert e.echo is True + + def test_custom(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') def connect(): return dbapi.connect(foober=12, lala=18, fooz='somevalue', hoho={'this':'dict'}) - + # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg e = create_engine('postgres://', creator=connect, module=dbapi) c = e.connect() - - def testrecycle(self): + + def test_recycle(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') e = create_engine('postgres://', pool_recycle=472, module=dbapi) assert e.pool._recycle == 472 - - def testbadargs(self): + + def test_badargs(self): # good arg, use MockDBAPI to prevent oracle import errors e = create_engine('oracle://', use_ansi=True, module=MockDBAPI()) - + try: e = create_engine("foobar://", module=MockDBAPI()) assert False except ImportError: - assert True - + assert True + # bad arg try: e = create_engine('postgres://', use_ansi=True, module=MockDBAPI()) assert False except TypeError: assert True - + # bad arg try: e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI()) assert False except TypeError: assert True - + try: e = create_engine('postgres://', lala=5, module=MockDBAPI()) assert False except TypeError: assert True - + try: e = create_engine('sqlite://', lala=5) assert False @@ -114,31 +174,30 @@ 
class CreateEngineTest(PersistTest): assert False except TypeError: assert True - + e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True) - + e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) try: c = e.connect() assert False except exceptions.DBAPIError: assert True - - def testurlattr(self): + + def test_urlattr(self): """test the url attribute on ``Engine``.""" - + e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI()) u = url.make_url('mysql://scott:tiger@localhost/test') e2 = create_engine(u, module=MockDBAPI()) assert e.url.drivername == e2.url.drivername == 'mysql' assert e.url.username == e2.url.username == 'scott' assert e2.url is u - - def testpoolargs(self): + + def test_poolargs(self): """test that connection pool args make it thru""" - e = create_engine('postgres://', creator=None, pool_recycle=-1, echo_pool=None, auto_close_cursors=False, disallow_open_cursors=True, module=MockDBAPI()) - assert e.pool.auto_close_cursors is False - assert e.pool.disallow_open_cursors is True + e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI()) + assert e.pool._recycle == 50 # these args work for QueuePool e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=pool.QueuePool, module=MockDBAPI()) @@ -169,7 +228,6 @@ class MockCursor(object): def close(self): pass mock_dbapi = MockDBAPI() - + if __name__ == "__main__": - testbase.main() - + testenv.main() diff --git a/test/engine/pool.py b/test/engine/pool.py index 364afa9d75..75cb08e3c8 100644 --- a/test/engine/pool.py +++ b/test/engine/pool.py @@ -1,6 +1,7 @@ -import testbase -import threading, thread, time +import testenv; testenv.configure_for_tests() +import threading, thread, time, gc import sqlalchemy.pool as pool +import sqlalchemy.interfaces as interfaces import sqlalchemy.exceptions as exceptions from testlib import * @@ -31,19 +32,19 @@ class MockCursor(object): def close(self): pass mock_dbapi = MockDBAPI() - -class PoolTest(PersistTest): - + +class PoolTest(TestBase): + def setUp(self): pool.clear_managers() def testmanager(self): manager = pool.manage(mock_dbapi, use_threadlocal=True) - + connection = manager.connect('foo.db') connection2 = manager.connect('foo.db') connection3 = manager.connect('bar.db') - + print "connection " + repr(connection) self.assert_(connection.cursor() is not None) self.assert_(connection is connection2) @@ -56,10 +57,10 @@ class PoolTest(PersistTest): connection = manager.connect(None) except: pass - + def testnonthreadlocalmanager(self): manager = pool.manage(mock_dbapi, use_threadlocal = False) - + connection = manager.connect('foo.db') connection2 = manager.connect('foo.db') @@ -76,12 +77,12 @@ class PoolTest(PersistTest): def _do_testqueuepool(self, useclose=False): p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = False) - + def status(pool): tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup return tup - + c1 = p.connect() self.assert_(status(p) == (3,0,-2,1)) c2 = p.connect() @@ -110,13 +111,13 @@ class PoolTest(PersistTest): self.assert_(status(p) == (3,3,0,0)) c1 = p.connect() c2 = p.connect() - self.assert_(status(p) == (3, 1, 0, 2)) + self.assert_(status(p) == (3, 1, 0, 2), status(p)) if useclose: c2.close() else: c2 
= None self.assert_(status(p) == (3, 2, 0, 1)) - + def test_timeout(self): p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) c1 = p.connect() @@ -130,11 +131,13 @@ class PoolTest(PersistTest): assert int(time.time() - now) == 2 def test_timeout_race(self): - # test a race condition where the initial connecting threads all race to queue.Empty, then block on the mutex. - # each thread consumes a connection as they go in. when the limit is reached, the remaining threads - # go in, and get TimeoutError; even though they never got to wait for the timeout on queue.get(). - # the fix involves checking the timeout again within the mutex, and if so, unlocking and throwing them back to the start - # of do_get() + # test a race condition where the initial connecting threads all race + # to queue.Empty, then block on the mutex. each thread consumes a + # connection as they go in. when the limit is reached, the remaining + # threads go in, and get TimeoutError; even though they never got to + # wait for the timeout on queue.get(). the fix involves checking the + # timeout again within the mutex, and if so, unlocking and throwing + # them back to the start of do_get() p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db', delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) timeouts = [] def checkout(): @@ -147,7 +150,7 @@ class PoolTest(PersistTest): continue time.sleep(4) c1.close() - + threads = [] for i in xrange(10): th = threading.Thread(target=checkout) @@ -155,17 +158,17 @@ class PoolTest(PersistTest): threads.append(th) for th in threads: th.join() - + print timeouts assert len(timeouts) > 0 for t in timeouts: assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) - + def _test_overflow(self, thread_count, max_overflow): def creator(): time.sleep(.05) return mock_dbapi.connect('foo.db') - + p = pool.QueuePool(creator=creator, pool_size=3, timeout=2, max_overflow=max_overflow) @@ -195,7 +198,7 @@ class PoolTest(PersistTest): def test_max_overflow(self): self._test_overflow(40, 5) - + def test_mixed_close(self): p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True) c1 = p.connect() @@ -206,7 +209,20 @@ class PoolTest(PersistTest): assert p.checkedout() == 1 c1 = None assert p.checkedout() == 0 - + + def test_weakref_kaboom(self): + p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + c1.close() + c2 = None + del c1 + del c2 + gc.collect() + assert p.checkedout() == 0 + c3 = p.connect() + assert c3 is not None + def test_trick_the_counter(self): """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an @@ -225,7 +241,7 @@ class PoolTest(PersistTest): def test_recycle(self): p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) - + c1 = p.connect() c_id = id(c1.connection) c1.close() @@ -235,7 +251,7 @@ class PoolTest(PersistTest): time.sleep(4) c3= p.connect() assert id(c3.connection) != c_id - + def test_invalidate(self): dbapi = MockDBAPI() p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = 
False) @@ -246,7 +262,7 @@ class PoolTest(PersistTest): assert c1.connection.id == c_id c1.invalidate() c1 = None - + c1 = p.connect() assert c1.connection.id != c_id @@ -257,9 +273,9 @@ class PoolTest(PersistTest): assert p2.size() == 1 assert p2._use_threadlocal is False assert p2._max_overflow == 0 - + def test_reconnect(self): - """tests reconnect operations at the pool level. SA's engine/dialect includes another + """tests reconnect operations at the pool level. SA's engine/dialect includes another layer of reconnect support for 'database was lost' errors.""" dbapi = MockDBAPI() p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) @@ -299,7 +315,14 @@ class PoolTest(PersistTest): assert not con.closed c1.close() assert con.closed - + + def test_threadfairy(self): + p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c1.close() + c2 = p.connect() + assert c2.connection is not None + def testthreadlocal_del(self): self._do_testthreadlocal(useclose=False) @@ -310,7 +333,7 @@ class PoolTest(PersistTest): for p in ( pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True), pool.SingletonThreadPool(creator = lambda: mock_dbapi.connect('foo.db'), use_threadlocal = True) - ): + ): c1 = p.connect() c2 = p.connect() self.assert_(c1 is c2) @@ -327,7 +350,7 @@ class PoolTest(PersistTest): c2.close() else: c2 = None - + if useclose: c1 = p.connect() c2 = p.connect() @@ -338,8 +361,8 @@ class PoolTest(PersistTest): c1.close() c1 = c2 = c3 = None - - # extra tests with QueuePool to insure connections get __del__()ed when dereferenced + + # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced if isinstance(p, pool.QueuePool): self.assert_(p.checkedout() == 0) c1 = p.connect() @@ -355,35 +378,241 @@ class PoolTest(PersistTest): def test_properties(self): dbapi = MockDBAPI() p = pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), - pool_size=1, max_overflow=0) + pool_size=1, max_overflow=0, use_threadlocal=False) c = p.connect() - self.assert_(not c.properties) - self.assert_(c.properties is c._connection_record.properties) + self.assert_(not c.info) + self.assert_(c.info is c._connection_record.info) - c.properties['foo'] = 'bar' + c.info['foo'] = 'bar' c.close() del c c = p.connect() - self.assert_('foo' in c.properties) + self.assert_('foo' in c.info) c.invalidate() c = p.connect() - self.assert_('foo' not in c.properties) + self.assert_('foo' not in c.info) - c.properties['foo2'] = 'bar2' + c.info['foo2'] = 'bar2' c.detach() - self.assert_('foo2' in c.properties) + self.assert_('foo2' in c.info) c2 = p.connect() self.assert_(c.connection is not c2.connection) - self.assert_(not c2.properties) - self.assert_('foo2' in c.properties) - + self.assert_(not c2.info) + self.assert_('foo2' in c.info) + + def test_listeners(self): + dbapi = MockDBAPI() + + class InstrumentingListener(object): + def __init__(self): + if hasattr(self, 'connect'): + self.connect = self.inst_connect + if hasattr(self, 'checkout'): + self.checkout = self.inst_checkout + if hasattr(self, 'checkin'): + self.checkin = self.inst_checkin + self.clear() + def clear(self): + self.connected = [] + self.checked_out = [] + self.checked_in = [] + def assert_total(innerself, conn, cout, cin): + self.assert_(len(innerself.connected) == conn) + self.assert_(len(innerself.checked_out) == 
cout) + self.assert_(len(innerself.checked_in) == cin) + def assert_in(innerself, item, in_conn, in_cout, in_cin): + self.assert_((item in innerself.connected) == in_conn) + self.assert_((item in innerself.checked_out) == in_cout) + self.assert_((item in innerself.checked_in) == in_cin) + def inst_connect(self, con, record): + print "connect(%s, %s)" % (con, record) + assert con is not None + assert record is not None + self.connected.append(con) + def inst_checkout(self, con, record, proxy): + print "checkout(%s, %s, %s)" % (con, record, proxy) + assert con is not None + assert record is not None + assert proxy is not None + self.checked_out.append(con) + def inst_checkin(self, con, record): + print "checkin(%s, %s)" % (con, record) + # con can be None if invalidated + assert record is not None + self.checked_in.append(con) + class ListenAll(interfaces.PoolListener, InstrumentingListener): + pass + class ListenConnect(InstrumentingListener): + def connect(self, con, record): + pass + class ListenCheckOut(InstrumentingListener): + def checkout(self, con, record, proxy, num): + pass + class ListenCheckIn(InstrumentingListener): + def checkin(self, con, record): + pass + + def _pool(**kw): + return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), + use_threadlocal=False, **kw) + + def assert_listeners(p, total, conn, cout, cin): + for instance in (p, p.recreate()): + self.assert_(len(instance.listeners) == total) + self.assert_(len(instance._on_connect) == conn) + self.assert_(len(instance._on_checkout) == cout) + self.assert_(len(instance._on_checkin) == cin) + + p = _pool() + assert_listeners(p, 0, 0, 0, 0) + + p.add_listener(ListenAll()) + assert_listeners(p, 1, 1, 1, 1) + + p.add_listener(ListenConnect()) + assert_listeners(p, 2, 2, 1, 1) + + p.add_listener(ListenCheckOut()) + assert_listeners(p, 3, 2, 2, 1) + + p.add_listener(ListenCheckIn()) + assert_listeners(p, 4, 2, 2, 2) + del p + + print "----" + snoop = ListenAll() + p = _pool(listeners=[snoop]) + assert_listeners(p, 1, 1, 1, 1) + + c = p.connect() + snoop.assert_total(1, 1, 0) + cc = c.connection + snoop.assert_in(cc, True, True, False) + c.close() + snoop.assert_in(cc, True, True, True) + del c, cc + + snoop.clear() + + # this one depends on immediate gc + c = p.connect() + cc = c.connection + snoop.assert_in(cc, False, True, False) + snoop.assert_total(0, 1, 0) + del c, cc + snoop.assert_total(0, 1, 1) + + p.dispose() + snoop.clear() + + c = p.connect() + c.close() + c = p.connect() + snoop.assert_total(1, 2, 1) + c.close() + snoop.assert_total(1, 2, 2) + + # invalidation + p.dispose() + snoop.clear() + + c = p.connect() + snoop.assert_total(1, 1, 0) + c.invalidate() + snoop.assert_total(1, 1, 1) + c.close() + snoop.assert_total(1, 1, 1) + del c + snoop.assert_total(1, 1, 1) + c = p.connect() + snoop.assert_total(2, 2, 1) + c.close() + del c + snoop.assert_total(2, 2, 2) + + # detached + p.dispose() + snoop.clear() + + c = p.connect() + snoop.assert_total(1, 1, 0) + c.detach() + snoop.assert_total(1, 1, 0) + c.close() + del c + snoop.assert_total(1, 1, 0) + c = p.connect() + snoop.assert_total(2, 2, 0) + c.close() + del c + snoop.assert_total(2, 2, 1) + + def test_listeners_callables(self): + dbapi = MockDBAPI() + + counts = [0, 0, 0] + def connect(dbapi_con, con_record): + counts[0] += 1 + def checkout(dbapi_con, con_record, con_proxy): + counts[1] += 1 + def checkin(dbapi_con, con_record): + counts[2] += 1 + + i_all = dict(connect=connect, checkout=checkout, checkin=checkin) + i_connect = dict(connect=connect) + 
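A compact sketch of the dictionary-of-callables listener style these assertions cover, with sqlite3 standing in for the mock DBAPI (listener names are illustrative):

    import sqlite3
    import sqlalchemy.pool as pool

    events = []
    def on_connect(dbapi_con, con_record):
        events.append('connect')
    def on_checkout(dbapi_con, con_record, con_proxy):
        events.append('checkout')
    def on_checkin(dbapi_con, con_record):
        events.append('checkin')

    p = pool.QueuePool(creator=lambda: sqlite3.connect(':memory:'),
                       use_threadlocal=False,
                       listeners=[dict(connect=on_connect,
                                       checkout=on_checkout,
                                       checkin=on_checkin)])
    c = p.connect()
    c.close()
    # one connect, one checkout, one checkin, mirroring the counts asserted here
    assert (events.count('connect'),
            events.count('checkout'),
            events.count('checkin')) == (1, 1, 1)
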
i_checkout = dict(checkout=checkout) + i_checkin = dict(checkin=checkin) + + def _pool(**kw): + return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), + use_threadlocal=False, **kw) + + def assert_listeners(p, total, conn, cout, cin): + for instance in (p, p.recreate()): + self.assert_(len(instance.listeners) == total) + self.assert_(len(instance._on_connect) == conn) + self.assert_(len(instance._on_checkout) == cout) + self.assert_(len(instance._on_checkin) == cin) + + p = _pool() + assert_listeners(p, 0, 0, 0, 0) + + p.add_listener(i_all) + assert_listeners(p, 1, 1, 1, 1) + + p.add_listener(i_connect) + assert_listeners(p, 2, 2, 1, 1) + + p.add_listener(i_checkout) + assert_listeners(p, 3, 2, 2, 1) + + p.add_listener(i_checkin) + assert_listeners(p, 4, 2, 2, 2) + del p + + p = _pool(listeners=[i_all]) + assert_listeners(p, 1, 1, 1, 1) + + c = p.connect() + assert counts == [1, 1, 0] + c.close() + assert counts == [1, 1, 1] + + c = p.connect() + assert counts == [1, 2, 1] + p.add_listener(i_checkin) + c.close() + assert counts == [1, 2, 3] + + + def tearDown(self): pool.clear_managers() - - + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py index 7c213695f2..d0d037a340 100644 --- a/test/engine/reconnect.py +++ b/test/engine/reconnect.py @@ -1,6 +1,6 @@ -import testbase +import testenv; testenv.configure_for_tests() import sys, weakref -from sqlalchemy import create_engine, exceptions +from sqlalchemy import create_engine, exceptions, select, MetaData, Table, Column, Integer, String from testlib import * @@ -13,53 +13,60 @@ class MockDBAPI(object): self.connections = weakref.WeakKeyDictionary() def connect(self, *args, **kwargs): return MockConnection(self) - + def shutdown(self): + for c in self.connections: + c.explode[0] = True + Error = MockDisconnect + class MockConnection(object): def __init__(self, dbapi): - self.explode = False dbapi.connections[self] = True + self.explode = [False] def rollback(self): pass def commit(self): pass def cursor(self): - return MockCursor(explode=self.explode) + return MockCursor(self) def close(self): pass - + class MockCursor(object): - def __init__(self, explode): - self.explode = explode + def __init__(self, parent): + self.explode = parent.explode self.description = None def execute(self, *args, **kwargs): - if self.explode: + if self.explode[0]: raise MockDisconnect("Lost the DB connection") else: return def close(self): pass - -class ReconnectTest(PersistTest): - def test_reconnect(self): - """test that an 'is_disconnect' condition will invalidate the connection, and additionally - dispose the previous connection pool and recreate.""" - + +class MockReconnectTest(TestBase): + def setUp(self): + global db, dbapi dbapi = MockDBAPI() - + # create engine using our current dburi db = create_engine('postgres://foo:bar@localhost/test', module=dbapi) - + # monkeypatch disconnect checker db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect) - + + def test_reconnect(self): + """test that an 'is_disconnect' condition will invalidate the connection, and additionally + dispose the previous connection pool and recreate.""" + + pid = id(db.pool) - + # make a connection conn = db.connect() - + # connection works - conn.execute("SELECT 1") - + conn.execute(select([1])) + # create a second connection within the pool, which we'll ensure also goes away conn2 = db.connect() conn2.close() @@ -68,30 +75,240 @@ class ReconnectTest(PersistTest): assert 
len(dbapi.connections) == 2 # set it to fail - conn.connection.connection.explode = True - + dbapi.shutdown() + try: - # execute should fail - conn.execute("SELECT 1") + conn.execute(select([1])) assert False - except exceptions.SQLAlchemyError, e: + except exceptions.DBAPIError: pass - + # assert was invalidated - assert conn.connection.connection is None - + assert not conn.closed + assert conn.invalidated + # close shouldnt break conn.close() assert id(db.pool) != pid - + # ensure all connections closed (pool was recycled) assert len(dbapi.connections) == 0 - + conn =db.connect() - conn.execute("SELECT 1") + conn.execute(select([1])) conn.close() assert len(dbapi.connections) == 1 + + def test_invalidate_trans(self): + conn = db.connect() + trans = conn.begin() + dbapi.shutdown() + + try: + conn.execute(select([1])) + assert False + except exceptions.DBAPIError: + pass + + # assert was invalidated + assert len(dbapi.connections) == 0 + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute(select([1])) + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + conn.execute(select([1])) + assert not conn.invalidated + + assert len(dbapi.connections) == 1 + + def test_conn_reusable(self): + conn = db.connect() + + conn.execute(select([1])) + + assert len(dbapi.connections) == 1 + + dbapi.shutdown() + + # raises error + try: + conn.execute(select([1])) + assert False + except exceptions.DBAPIError: + pass + + assert not conn.closed + assert conn.invalidated + + # ensure all connections closed (pool was recycled) + assert len(dbapi.connections) == 0 + + # test reconnects + conn.execute(select([1])) + assert not conn.invalidated + assert len(dbapi.connections) == 1 + + +class RealReconnectTest(TestBase): + def setUp(self): + global engine + engine = engines.reconnecting_engine() + + def tearDown(self): + engine.dispose() + + def test_reconnect(self): + conn = engine.connect() + + self.assertEquals(conn.execute(select([1])).scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute(select([1])) + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + + assert conn.invalidated + self.assertEquals(conn.execute(select([1])).scalar(), 1) + assert not conn.invalidated + + # one more time + engine.test_shutdown() + try: + conn.execute(select([1])) + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + assert conn.invalidated + self.assertEquals(conn.execute(select([1])).scalar(), 1) + assert not conn.invalidated + + conn.close() + + def test_close(self): + conn = engine.connect() + self.assertEquals(conn.execute(select([1])).scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute(select([1])) + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + conn.close() + conn = engine.connect() + self.assertEquals(conn.execute(select([1])).scalar(), 1) + + def test_with_transaction(self): + conn = engine.connect() + + trans = conn.begin() + + 
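The reconnect tests above rest on pool-level invalidation; a small sketch of that mechanism, with sqlite3 standing in for the mock DBAPI:

    import sqlite3
    import sqlalchemy.pool as pool

    p = pool.QueuePool(creator=lambda: sqlite3.connect(':memory:'),
                       pool_size=1, max_overflow=0, use_threadlocal=False)
    c1 = p.connect()
    first = c1.connection        # the underlying DBAPI connection
    c1.invalidate()              # mark it unusable
    c1 = None                    # return the record to the pool
    c2 = p.connect()             # next checkout opens a fresh DBAPI connection
    assert c2.connection is not first
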
self.assertEquals(conn.execute(select([1])).scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute(select([1])) + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute(select([1])) + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + assert conn.invalidated + self.assertEquals(conn.execute(select([1])).scalar(), 1) + assert not conn.invalidated + +class InvalidateDuringResultTest(TestBase): + def setUp(self): + global meta, table, engine + engine = engines.reconnecting_engine() + meta = MetaData(engine) + table = Table('sometable', meta, + Column('id', Integer, primary_key=True), + Column('name', String(50))) + meta.create_all() + table.insert().execute( + [{'id':i, 'name':'row %d' % i} for i in range(1, 100)] + ) + + def tearDown(self): + meta.drop_all() + engine.dispose() + + @testing.fails_on('mysql') + def test_invalidate_on_results(self): + conn = engine.connect() + + result = conn.execute("select * from sometable") + for x in xrange(20): + result.fetchone() + + engine.test_shutdown() + try: + result.fetchone() + assert False + except exceptions.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert conn.invalidated if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 00c1276eeb..2ace3306a2 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -1,169 +1,216 @@ -import testbase -import pickle, StringIO - +import testenv; testenv.configure_for_tests() +import StringIO, unicodedata from sqlalchemy import * -import sqlalchemy.ansisql as ansisql -from sqlalchemy.exceptions import NoSuchTableError -import sqlalchemy.databases.mysql as mysql +from sqlalchemy import exceptions +from sqlalchemy import types as sqltypes from testlib import * +from testlib import engines -class ReflectionTest(PersistTest): - def testbasic(self): - use_function_defaults = testbase.db.engine.name == 'postgres' or testbase.db.engine.name == 'oracle' - - use_string_defaults = use_function_defaults or testbase.db.engine.__module__.endswith('sqlite') +class ReflectionTest(TestBase, ComparesTables): - if (testbase.db.engine.name == 'mysql' and - testbase.db.dialect.get_version_info(testbase.db) < (4, 1, 1)): - return + @testing.exclude('mysql', '<', (4, 1, 1)) + def test_basic_reflection(self): + meta = MetaData(testing.db) - if use_function_defaults: - defval = func.current_date() - deftype = Date - else: - defval = "3" - deftype = Integer - - if use_string_defaults: - deftype2 = String - defval2 = "im a default" - #deftype3 = DateTime - # the colon thing isnt working out for PG reflection just yet - #defval3 = '1999-09-09 00:00:00' - deftype3 = Date - defval3 = '1999-09-09' - else: - deftype2, deftype3 = Integer, Integer - defval2, defval3 = "15", "16" - - meta = MetaData(testbase.db) - users = Table('engine_users', meta, - Column('user_id', INT, primary_key = True), - Column('user_name', VARCHAR(20), nullable = False), - Column('test1', CHAR(5), nullable = False), - Column('test2', 
FLOAT(5), nullable = False), - Column('test3', TEXT), - Column('test4', DECIMAL, nullable = False), - Column('test5', TIMESTAMP), + Column('user_id', INT, primary_key=True), + Column('user_name', VARCHAR(20), nullable=False), + Column('test1', CHAR(5), nullable=False), + Column('test2', Float(5), nullable=False), + Column('test3', Text), + Column('test4', Numeric, nullable = False), + Column('test5', DateTime), Column('parent_user_id', Integer, ForeignKey('engine_users.user_id')), - Column('test6', DateTime, nullable = False), - Column('test7', String), + Column('test6', DateTime, nullable=False), + Column('test7', Text), Column('test8', Binary), - Column('test_passivedefault', deftype, PassiveDefault(defval)), Column('test_passivedefault2', Integer, PassiveDefault("5")), - Column('test_passivedefault3', deftype2, PassiveDefault(defval2)), - Column('test_passivedefault4', deftype3, PassiveDefault(defval3)), Column('test9', Binary(100)), - Column('test_numeric', Numeric(None, None)), + Column('test_numeric', Numeric()), test_needs_fk=True, ) - + addresses = Table('engine_email_addresses', meta, Column('address_id', Integer, primary_key = True), Column('remote_user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(20)), test_needs_fk=True, ) - meta.drop_all() - - users.create() - addresses.create() - - # clear out table registry - meta.clear() + meta.create_all() try: - addresses = Table('engine_email_addresses', meta, autoload = True) - # reference the addresses foreign key col, which will require users to be - # reflected at some point - users = Table('engine_users', meta, autoload = True) - assert users.c.user_id in users.primary_key - assert len(users.primary_key) == 1 - finally: - addresses.drop() - users.drop() - - # a hack to remove the defaults we got reflecting from postgres - # SERIAL columns, since they reference sequences that were just dropped. - # PG 8.1 doesnt want to create them if the underlying sequence doesnt exist - users.c.user_id.default = None - addresses.c.address_id.default = None - - users.create() - addresses.create() - try: - # create a join from the two tables, this ensures that - # theres a foreign key set up - # previously, we couldnt get foreign keys out of mysql. 
seems like - # we can now as long as we use InnoDB -# if testbase.db.engine.__module__.endswith('mysql'): - # addresses.c.remote_user_id.append_item(ForeignKey('engine_users.user_id')) - print users - print addresses - j = join(users, addresses) - print str(j.onclause) - self.assert_((users.c.user_id==addresses.c.remote_user_id).compare(j.onclause)) + meta2 = MetaData() + reflected_users = Table('engine_users', meta2, autoload=True, autoload_with=testing.db) + reflected_addresses = Table('engine_email_addresses', meta2, autoload=True, autoload_with=testing.db) + self.assert_tables_equal(users, reflected_users) + self.assert_tables_equal(addresses, reflected_addresses) finally: addresses.drop() users.drop() - - def test_autoload_partial(self): - meta = MetaData(testbase.db) - foo = Table('foo', meta, - Column('a', String(30)), - Column('b', String(30)), - Column('c', String(30)), - Column('d', String(30)), - Column('e', String(30)), - Column('f', String(30)), - ) + + def test_include_columns(self): + meta = MetaData(testing.db) + foo = Table('foo', meta, *[Column(n, String(30)) for n in ['a', 'b', 'c', 'd', 'e', 'f']]) meta.create_all() try: - meta2 = MetaData(testbase.db) - foo2 = Table('foo', meta2, autoload=True, include_columns=['b', 'f', 'e']) + meta2 = MetaData(testing.db) + foo = Table('foo', meta2, autoload=True, include_columns=['b', 'f', 'e']) # test that cols come back in original order - assert [c.name for c in foo2.c] == ['b', 'e', 'f'] + self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f']) for c in ('b', 'f', 'e'): - assert c in foo2.c + assert c in foo.c for c in ('a', 'c', 'd'): - assert c not in foo2.c + assert c not in foo.c + + # test against a table which is already reflected + meta3 = MetaData(testing.db) + foo = Table('foo', meta3, autoload=True) + foo = Table('foo', meta3, include_columns=['b', 'f', 'e'], useexisting=True) + self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f']) + for c in ('b', 'f', 'e'): + assert c in foo.c + for c in ('a', 'c', 'd'): + assert c not in foo.c finally: meta.drop_all() - - def testoverridecolumns(self): - """test that you can override columns which contain foreign keys to other reflected tables""" - meta = MetaData(testbase.db) - users = Table('users', meta, + + + def test_unknown_types(self): + meta = MetaData(testing.db) + t = Table("test", meta, + Column('foo', DateTime)) + + import sys + dialect_module = sys.modules[testing.db.dialect.__module__] + + # we're relying on the presence of "ischema_names" in the + # dialect module, else we can't test this. we need to be able + # to get the dialect to not be aware of some type so we temporarily + # monkeypatch. 
not sure what a better way for this could be, + # except for an established dialect hook or dialect-specific tests + if not hasattr(dialect_module, 'ischema_names'): + return + + ischema_names = dialect_module.ischema_names + t.create() + dialect_module.ischema_names = {} + try: + try: + m2 = MetaData(testing.db) + t2 = Table("test", m2, autoload=True) + assert False + except exceptions.SAWarning: + assert True + + @testing.emits_warning('Did not recognize type') + def warns(): + m3 = MetaData(testing.db) + t3 = Table("test", m3, autoload=True) + assert t3.c.foo.type.__class__ == sqltypes.NullType + + finally: + dialect_module.ischema_names = ischema_names + t.drop() + + def test_basic_override(self): + meta = MetaData(testing.db) + table = Table( + 'override_test', meta, + Column('col1', Integer, primary_key=True), + Column('col2', String(20)), + Column('col3', Numeric) + ) + table.create() + + meta2 = MetaData(testing.db) + try: + table = Table( + 'override_test', meta2, + Column('col2', Unicode()), + Column('col4', String(30)), autoload=True) + + self.assert_(isinstance(table.c.col1.type, Integer)) + self.assert_(isinstance(table.c.col2.type, Unicode)) + self.assert_(isinstance(table.c.col4.type, String)) + finally: + table.drop() + + def test_override_pkfk(self): + """test that you can override columns which contain foreign keys to other reflected tables, + where the foreign key column is also a primary key column""" + + meta = MetaData(testing.db) + users = Table('users', meta, + Column('id', Integer, primary_key=True), + Column('name', String(30))) + addresses = Table('addresses', meta, + Column('id', Integer, primary_key=True), + Column('street', String(30))) + + + meta.create_all() + try: + meta2 = MetaData(testing.db) + a2 = Table('addresses', meta2, + Column('id', Integer, ForeignKey('users.id'), primary_key=True), + autoload=True) + u2 = Table('users', meta2, autoload=True) + + assert list(a2.primary_key) == [a2.c.id] + assert list(u2.primary_key) == [u2.c.id] + assert u2.join(a2).onclause == u2.c.id==a2.c.id + + meta3 = MetaData(testing.db) + u3 = Table('users', meta3, autoload=True) + a3 = Table('addresses', meta3, + Column('id', Integer, ForeignKey('users.id'), primary_key=True), + autoload=True) + + assert list(a3.primary_key) == [a3.c.id] + assert list(u3.primary_key) == [u3.c.id] + assert u3.join(a3).onclause == u3.c.id==a3.c.id + + finally: + meta.drop_all() + + def test_override_nonexistent_fk(self): + """test that you can override columns and create new foreign keys to other reflected tables + which have no foreign keys. 
this is common with MySQL MyISAM tables.""" + + meta = MetaData(testing.db) + users = Table('users', meta, Column('id', Integer, primary_key=True), Column('name', String(30))) addresses = Table('addresses', meta, Column('id', Integer, primary_key=True), Column('street', String(30)), Column('user_id', Integer)) - - meta.create_all() + + meta.create_all() try: - meta2 = MetaData(testbase.db) - a2 = Table('addresses', meta2, + meta2 = MetaData(testing.db) + a2 = Table('addresses', meta2, Column('user_id', Integer, ForeignKey('users.id')), autoload=True) u2 = Table('users', meta2, autoload=True) - - assert len(a2.c.user_id.foreign_keys)>0 + + assert len(a2.c.user_id.foreign_keys) == 1 + assert len(a2.foreign_keys) == 1 + assert [c.parent for c in a2.foreign_keys] == [a2.c.user_id] + assert [c.parent for c in a2.c.user_id.foreign_keys] == [a2.c.user_id] assert list(a2.c.user_id.foreign_keys)[0].parent is a2.c.user_id assert u2.join(a2).onclause == u2.c.id==a2.c.user_id - meta3 = MetaData(testbase.db) + meta3 = MetaData(testing.db) u3 = Table('users', meta3, autoload=True) - a3 = Table('addresses', meta3, + a3 = Table('addresses', meta3, Column('user_id', Integer, ForeignKey('users.id')), autoload=True) - + assert u3.join(a3).onclause == u3.c.id==a3.c.user_id - meta4 = MetaData(testbase.db) + meta4 = MetaData(testing.db) u4 = Table('users', meta4, Column('id', Integer, key='u_id', primary_key=True), autoload=True) @@ -183,166 +230,153 @@ class ReflectionTest(PersistTest): finally: meta.drop_all() - def testoverridecolumns2(self): - """test that you can override columns which contain foreign keys to other reflected tables, - where the foreign key column is also a primary key column""" - meta = MetaData(testbase.db) - users = Table('users', meta, + def test_override_existing_fk(self): + """test that you can override columns and specify new foreign keys to other reflected tables, + on columns which *do* already have that foreign key, and that the FK is not duped. 
+ """ + + meta = MetaData(testing.db) + users = Table('users', meta, Column('id', Integer, primary_key=True), - Column('name', String(30))) + Column('name', String(30)), + test_needs_fk=True) addresses = Table('addresses', meta, - Column('id', Integer, primary_key=True), - Column('street', String(30))) - + Column('id', Integer,primary_key=True), + Column('user_id', Integer, ForeignKey('users.id')), + test_needs_fk=True) - meta.create_all() + meta.create_all() try: - meta2 = MetaData(testbase.db) - a2 = Table('addresses', meta2, - Column('id', Integer, ForeignKey('users.id'), primary_key=True, ), + meta2 = MetaData(testing.db) + a2 = Table('addresses', meta2, + Column('user_id',Integer, ForeignKey('users.id')), autoload=True) u2 = Table('users', meta2, autoload=True) - assert list(a2.primary_key) == [a2.c.id] - assert list(u2.primary_key) == [u2.c.id] - assert u2.join(a2).onclause == u2.c.id==a2.c.id + assert len(a2.foreign_keys) == 1 + assert len(a2.c.user_id.foreign_keys) == 1 + assert len(a2.constraints) == 2 + assert [c.parent for c in a2.foreign_keys] == [a2.c.user_id] + assert [c.parent for c in a2.c.user_id.foreign_keys] == [a2.c.user_id] + assert list(a2.c.user_id.foreign_keys)[0].parent is a2.c.user_id + assert u2.join(a2).onclause == u2.c.id==a2.c.user_id - # heres what was originally failing, because a2's primary key - # had two "id" columns, one of which was not part of a2's "c" collection - #class Address(object):pass - #mapper(Address, a2) - #add1 = Address() - #sess = create_session() - #sess.save(add1) - #sess.flush() - - meta3 = MetaData(testbase.db) - u3 = Table('users', meta3, autoload=True) - a3 = Table('addresses', meta3, - Column('id', Integer, ForeignKey('users.id'), primary_key=True), + meta2 = MetaData(testing.db) + u2 = Table('users', meta2, + Column('id', Integer, primary_key=True), + autoload=True) + a2 = Table('addresses', meta2, + Column('id', Integer, primary_key=True), + Column('user_id',Integer, ForeignKey('users.id')), autoload=True) - assert list(a3.primary_key) == [a3.c.id] - assert list(u3.primary_key) == [u3.c.id] - assert u3.join(a3).onclause == u3.c.id==a3.c.id - + assert len(a2.foreign_keys) == 1 + assert len(a2.c.user_id.foreign_keys) == 1 + assert len(a2.constraints) == 2 + assert [c.parent for c in a2.foreign_keys] == [a2.c.user_id] + assert [c.parent for c in a2.c.user_id.foreign_keys] == [a2.c.user_id] + assert list(a2.c.user_id.foreign_keys)[0].parent is a2.c.user_id + assert u2.join(a2).onclause == u2.c.id==a2.c.user_id finally: meta.drop_all() - - @testing.supported('mysql') - def testmysqltypes(self): - meta1 = MetaData(testbase.db) - table = Table( - 'mysql_types', meta1, + + def test_use_existing(self): + meta = MetaData(testing.db) + users = Table('users', meta, Column('id', Integer, primary_key=True), - Column('num1', mysql.MSInteger(unsigned=True)), - Column('text1', mysql.MSLongText), - Column('text2', mysql.MSLongText()), - Column('num2', mysql.MSBigInteger), - Column('num3', mysql.MSBigInteger()), - Column('num4', mysql.MSDouble), - Column('num5', mysql.MSDouble()), - Column('enum1', mysql.MSEnum('"black"', '"white"')), - ) + Column('name', String(30)), + test_needs_fk=True) + addresses = Table('addresses', meta, + Column('id', Integer,primary_key=True), + Column('user_id', Integer, ForeignKey('users.id')), + Column('data', String(100)), + test_needs_fk=True) + + meta.create_all() try: - table.drop(checkfirst=True) - table.create() - meta2 = MetaData(testbase.db) - t2 = Table('mysql_types', meta2, autoload=True) - assert 
isinstance(t2.c.num1.type, mysql.MSInteger) - assert t2.c.num1.type.unsigned - assert isinstance(t2.c.text1.type, mysql.MSLongText) - assert isinstance(t2.c.text2.type, mysql.MSLongText) - assert isinstance(t2.c.num2.type, mysql.MSBigInteger) - assert isinstance(t2.c.num3.type, mysql.MSBigInteger) - assert isinstance(t2.c.num4.type, mysql.MSDouble) - assert isinstance(t2.c.num5.type, mysql.MSDouble) - assert isinstance(t2.c.enum1.type, mysql.MSEnum) - t2.drop() - t2.create() + meta2 = MetaData(testing.db) + addresses = Table('addresses', meta2, Column('data', Unicode), autoload=True) + try: + users = Table('users', meta2, Column('name', Unicode), autoload=True) + assert False + except exceptions.InvalidRequestError, err: + assert str(err) == "Table 'users' is already defined for this MetaData instance. Specify 'useexisting=True' to redefine options and columns on an existing Table object." + + users = Table('users', meta2, Column('name', Unicode), autoload=True, useexisting=True) + assert isinstance(users.c.name.type, Unicode) + + assert not users.quote + + users = Table('users', meta2, quote=True, autoload=True, useexisting=True) + assert users.quote + finally: - table.drop(checkfirst=True) + meta.drop_all() - def test_pks_not_uniques(self): """test that primary key reflection not tripped up by unique indexes""" - testbase.db.execute(""" + + testing.db.execute(""" CREATE TABLE book ( id INTEGER NOT NULL, title VARCHAR(100) NOT NULL, - series INTEGER NULL, - series_id INTEGER NULL, + series INTEGER, + series_id INTEGER, UNIQUE(series, series_id), PRIMARY KEY(id) )""") try: - metadata = MetaData(bind=testbase.db) + metadata = MetaData(bind=testing.db) book = Table('book', metadata, autoload=True) assert book.c.id in book.primary_key assert book.c.series not in book.primary_key assert len(book.primary_key) == 1 finally: - testbase.db.execute("drop table book") + testing.db.execute("drop table book") + + def test_fk_error(self): + metadata = MetaData(testing.db) + slots_table = Table('slots', metadata, + Column('slot_id', Integer, primary_key=True), + Column('pkg_id', Integer, ForeignKey('pkgs.pkg_id')), + Column('slot', String(128)), + ) + try: + metadata.create_all() + assert False + except exceptions.InvalidRequestError, err: + assert str(err) == "Could not find table 'pkgs' with which to generate a foreign key" def test_composite_pks(self): """test reflection of a composite primary key""" - testbase.db.execute(""" + + testing.db.execute(""" CREATE TABLE book ( id INTEGER NOT NULL, isbn VARCHAR(50) NOT NULL, title VARCHAR(100) NOT NULL, - series INTEGER NULL, - series_id INTEGER NULL, + series INTEGER, + series_id INTEGER, UNIQUE(series, series_id), PRIMARY KEY(id, isbn) )""") try: - metadata = MetaData(bind=testbase.db) + metadata = MetaData(bind=testing.db) book = Table('book', metadata, autoload=True) assert book.c.id in book.primary_key assert book.c.isbn in book.primary_key assert book.c.series not in book.primary_key assert len(book.primary_key) == 2 finally: - testbase.db.execute("drop table book") - - @testing.supported('sqlite') - def test_goofy_sqlite(self): - """test autoload of table where quotes were used with all the colnames. 
quirky in sqlite.""" - testbase.db.execute("""CREATE TABLE "django_content_type" ( - "id" integer NOT NULL PRIMARY KEY, - "django_stuff" text NULL - ) - """) - testbase.db.execute(""" - CREATE TABLE "django_admin_log" ( - "id" integer NOT NULL PRIMARY KEY, - "action_time" datetime NOT NULL, - "content_type_id" integer NULL REFERENCES "django_content_type" ("id"), - "object_id" text NULL, - "change_message" text NOT NULL - ) - """) - try: - meta = MetaData(testbase.db) - table1 = Table("django_admin_log", meta, autoload=True) - table2 = Table("django_content_type", meta, autoload=True) - j = table1.join(table2) - assert j.onclause == table1.c.content_type_id==table2.c.id - finally: - testbase.db.execute("drop table django_admin_log") - testbase.db.execute("drop table django_content_type") + testing.db.execute("drop table book") + @testing.exclude('mysql', '<', (4, 1, 1)) def test_composite_fk(self): """test reflection of composite foreign keys""" - if (testbase.db.engine.name == 'mysql' and - testbase.db.dialect.get_version_info(testbase.db) < (4, 1, 1)): - return - meta = MetaData(testbase.db) - - table = Table( - 'multi', meta, + meta = MetaData(testing.db) + multi = Table( + 'multi', meta, Column('multi_id', Integer, primary_key=True), Column('multi_rev', Integer, primary_key=True), Column('multi_hoho', Integer, primary_key=True), @@ -350,7 +384,7 @@ class ReflectionTest(PersistTest): Column('val', String(100)), test_needs_fk=True, ) - table2 = Table('multi2', meta, + multi2 = Table('multi2', meta, Column('id', Integer, primary_key=True), Column('foo', Integer), Column('bar', Integer), @@ -359,159 +393,43 @@ class ReflectionTest(PersistTest): ForeignKeyConstraint(['foo', 'bar', 'lala'], ['multi.multi_id', 'multi.multi_rev', 'multi.multi_hoho']), test_needs_fk=True, ) - assert table.c.multi_hoho meta.create_all() - meta.clear() - + try: - table = Table('multi', meta, autoload=True) - table2 = Table('multi2', meta, autoload=True) - - print table - print table2 + meta2 = MetaData() + table = Table('multi', meta2, autoload=True, autoload_with=testing.db) + table2 = Table('multi2', meta2, autoload=True, autoload_with=testing.db) + self.assert_tables_equal(multi, table) + self.assert_tables_equal(multi2, table2) j = join(table, table2) - print str(j.onclause) self.assert_(and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar, table.c.multi_hoho==table2.c.lala).compare(j.onclause)) - finally: meta.drop_all() - def test_to_metadata(self): - meta = MetaData() - - table = Table('mytable', meta, - Column('myid', Integer, primary_key=True), - Column('name', String(40), nullable=False), - Column('description', String(30), CheckConstraint("description='hi'")), - UniqueConstraint('name'), - mysql_engine='InnoDB' - ) - - table2 = Table('othertable', meta, - Column('id', Integer, primary_key=True), - Column('myid', Integer, ForeignKey('mytable.myid')), - mysql_engine='InnoDB' - ) - - def test_to_metadata(): - meta2 = MetaData() - table_c = table.tometadata(meta2) - table2_c = table2.tometadata(meta2) - return (table_c, table2_c) - - def test_pickle(): - meta.connect(testbase.db) - meta2 = pickle.loads(pickle.dumps(meta)) - assert meta2.bind is None - return (meta2.tables['mytable'], meta2.tables['othertable']) - - def test_pickle_via_reflect(): - # this is the most common use case, pickling the results of a - # database reflection - meta2 = MetaData(bind=testbase.db) - t1 = Table('mytable', meta2, autoload=True) - t2 = Table('othertable', meta2, autoload=True) - meta3 = 
pickle.loads(pickle.dumps(meta2)) - assert meta3.bind is None - assert meta3.tables['mytable'] is not t1 - return (meta3.tables['mytable'], meta3.tables['othertable']) - - meta.create_all(testbase.db) - try: - for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)): - table_c, table2_c = test() - assert table is not table_c - assert table_c.c.myid.primary_key - assert isinstance(table_c.c.myid.type, Integer) - assert isinstance(table_c.c.name.type, String) - assert not table_c.c.name.nullable - assert table_c.c.description.nullable - assert table.primary_key is not table_c.primary_key - assert [x.name for x in table.primary_key] == [x.name for x in table_c.primary_key] - assert list(table2_c.c.myid.foreign_keys)[0].column is table_c.c.myid - assert list(table2_c.c.myid.foreign_keys)[0].column is not table.c.myid - - # constraints dont get reflected for any dialect right now - if has_constraints: - for c in table_c.c.description.constraints: - if isinstance(c, CheckConstraint): - break - else: - assert False - assert c.sqltext=="description='hi'" - - for c in table_c.constraints: - if isinstance(c, UniqueConstraint): - break - else: - assert False - assert c.columns.contains_column(table_c.c.name) - assert not c.columns.contains_column(table.c.name) - finally: - meta.drop_all(testbase.db) - - def test_nonexistent(self): - self.assertRaises(NoSuchTableError, Table, - 'fake_table', - MetaData(testbase.db), autoload=True) - - def testoverride(self): - meta = MetaData(testbase.db) - table = Table( - 'override_test', meta, - Column('col1', Integer, primary_key=True), - Column('col2', String(20)), - Column('col3', Numeric) - ) - table.create() - # clear out table registry - - meta2 = MetaData(testbase.db) - try: - table = Table( - 'override_test', meta2, - Column('col2', Unicode()), - Column('col4', String(30)), autoload=True) - - print repr(table) - self.assert_(isinstance(table.c.col1.type, Integer)) - self.assert_(isinstance(table.c.col2.type, Unicode)) - self.assert_(isinstance(table.c.col4.type, String)) - finally: - table.drop() - - @testing.supported('mssql') - def testidentity(self): - meta = MetaData(testbase.db) - table = Table( - 'identity_test', meta, - Column('col1', Integer, Sequence('fred', 2, 3), primary_key=True) - ) - table.create() - - meta2 = MetaData(testbase.db) - try: - table2 = Table('identity_test', meta2, autoload=True) - assert table2.c['col1'].sequence.start == 2 - assert table2.c['col1'].sequence.increment == 3 - finally: - table.drop() + @testing.unsupported('oracle') def testreserved(self): # check a table that uses an SQL reserved name doesn't cause an error - meta = MetaData(testbase.db) - table_a = Table('select', meta, + meta = MetaData(testing.db) + table_a = Table('select', meta, Column('not', Integer, primary_key=True), Column('from', String(12), nullable=False), UniqueConstraint('from', name='when')) Index('where', table_a.c['from']) + # There's currently no way to calculate identifier case normalization + # in isolation, so... 
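# [editor's note -- illustrative sketch, not part of this patch, kept
#  entirely in comments so the test body is unchanged.  It shows why the
#  reserved names used by these tables ('select', 'from', 'not', ...) are
#  usable at all: the dialect's identifier preparer quotes them.  The
#  in-memory SQLite engine below is assumed purely for demonstration.]
#
#   from sqlalchemy import create_engine
#   eng = create_engine('sqlite://')
#   quote = eng.dialect.identifier_preparer.quote_identifier
#   quote('from')     # e.g. '"from"' -- quoted, so usable as a column name
#   quote('select')   # e.g. '"select"' -- quoted, so usable as a table name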
+ if testing.against('firebird', 'oracle', 'maxdb'): + check_col = 'TRUE' + else: + check_col = 'true' quoter = meta.bind.dialect.identifier_preparer.quote_identifier table_b = Table('false', meta, Column('create', Integer, primary_key=True), Column('true', Integer, ForeignKey('select.not')), - CheckConstraint('%s <> 1' % quoter('true'), name='limit')) + CheckConstraint('%s <> 1' % quoter(check_col), + name='limit')) table_c = Table('is', meta, Column('or', Integer, nullable=False, primary_key=True), @@ -520,12 +438,11 @@ class ReflectionTest(PersistTest): index_c = Index('else', table_c.c.join) - #meta.bind.echo = True meta.create_all() index_c.drop() - - meta2 = MetaData(testbase.db) + + meta2 = MetaData(testing.db) try: table_a2 = Table('select', meta2, autoload=True) table_b2 = Table('false', meta2, autoload=True) @@ -533,12 +450,78 @@ class ReflectionTest(PersistTest): finally: meta.drop_all() -class CreateDropTest(PersistTest): + def test_reflect_all(self): + existing = testing.db.table_names() + + names = ['rt_%s' % name for name in ('a','b','c','d','e')] + nameset = set(names) + for name in names: + # be sure our starting environment is sane + self.assert_(name not in existing) + self.assert_('rt_f' not in existing) + + baseline = MetaData(testing.db) + for name in names: + Table(name, baseline, Column('id', Integer, primary_key=True)) + baseline.create_all() + + try: + m1 = MetaData(testing.db) + self.assert_(not m1.tables) + m1.reflect() + self.assert_(nameset.issubset(set(m1.tables.keys()))) + + m2 = MetaData() + m2.reflect(testing.db, only=['rt_a', 'rt_b']) + self.assert_(set(m2.tables.keys()) == set(['rt_a', 'rt_b'])) + + m3 = MetaData() + c = testing.db.connect() + m3.reflect(bind=c, only=lambda name, meta: name == 'rt_c') + self.assert_(set(m3.tables.keys()) == set(['rt_c'])) + + m4 = MetaData(testing.db) + try: + m4.reflect(only=['rt_a', 'rt_f']) + self.assert_(False) + except exceptions.InvalidRequestError, e: + self.assert_(e.args[0].endswith('(rt_f)')) + + m5 = MetaData(testing.db) + m5.reflect(only=[]) + self.assert_(not m5.tables) + + m6 = MetaData(testing.db) + m6.reflect(only=lambda n, m: False) + self.assert_(not m6.tables) + + m7 = MetaData(testing.db, reflect=True) + self.assert_(nameset.issubset(set(m7.tables.keys()))) + + try: + m8 = MetaData(reflect=True) + self.assert_(False) + except exceptions.ArgumentError, e: + self.assert_( + e.args[0] == + "A bind must be supplied in conjunction with reflect=True") + finally: + baseline.drop_all() + + if existing: + print "Other tables present in database, skipping some checks." 
+ else: + m9 = MetaData(testing.db) + m9.reflect() + self.assert_(not m9.tables) + + +class CreateDropTest(TestBase): def setUpAll(self): global metadata, users metadata = MetaData() users = Table('users', metadata, - Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), + Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True), Column('user_name', String(40)), ) @@ -546,7 +529,6 @@ class CreateDropTest(PersistTest): Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True), Column('user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(40)), - ) orders = Table('orders', metadata, @@ -554,14 +536,12 @@ class CreateDropTest(PersistTest): Column('user_id', Integer, ForeignKey(users.c.user_id)), Column('description', String(50)), Column('isopen', Integer), - ) orderitems = Table('items', metadata, Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True), Column('order_id', INT, ForeignKey("orders")), Column('item_name', VARCHAR(50)), - ) def test_sorter( self ): @@ -571,40 +551,100 @@ class CreateDropTest(PersistTest): def testcheckfirst(self): try: - assert not users.exists(testbase.db) - users.create(bind=testbase.db) - assert users.exists(testbase.db) - users.create(bind=testbase.db, checkfirst=True) - users.drop(bind=testbase.db) - users.drop(bind=testbase.db, checkfirst=True) - assert not users.exists(bind=testbase.db) - users.create(bind=testbase.db, checkfirst=True) - users.drop(bind=testbase.db) + assert not users.exists(testing.db) + users.create(bind=testing.db) + assert users.exists(testing.db) + users.create(bind=testing.db, checkfirst=True) + users.drop(bind=testing.db) + users.drop(bind=testing.db, checkfirst=True) + assert not users.exists(bind=testing.db) + users.create(bind=testing.db, checkfirst=True) + users.drop(bind=testing.db) finally: - metadata.drop_all(bind=testbase.db) + metadata.drop_all(bind=testing.db) + @testing.exclude('mysql', '<', (4, 1, 1)) def test_createdrop(self): - metadata.create_all(bind=testbase.db) - self.assertEqual( testbase.db.has_table('items'), True ) - self.assertEqual( testbase.db.has_table('email_addresses'), True ) - metadata.create_all(bind=testbase.db) - self.assertEqual( testbase.db.has_table('items'), True ) - - metadata.drop_all(bind=testbase.db) - self.assertEqual( testbase.db.has_table('items'), False ) - self.assertEqual( testbase.db.has_table('email_addresses'), False ) - metadata.drop_all(bind=testbase.db) - self.assertEqual( testbase.db.has_table('items'), False ) - -class SchemaTest(PersistTest): - # this test should really be in the sql tests somewhere, not engine - @testing.unsupported('sqlite') - def testiteration(self): + metadata.create_all(bind=testing.db) + self.assertEqual( testing.db.has_table('items'), True ) + self.assertEqual( testing.db.has_table('email_addresses'), True ) + metadata.create_all(bind=testing.db) + self.assertEqual( testing.db.has_table('items'), True ) + + metadata.drop_all(bind=testing.db) + self.assertEqual( testing.db.has_table('items'), False ) + self.assertEqual( testing.db.has_table('email_addresses'), False ) + metadata.drop_all(bind=testing.db) + self.assertEqual( testing.db.has_table('items'), False ) + + def test_tablenames(self): + from sqlalchemy.util import Set + metadata.create_all(bind=testing.db) + # we only check to see if all the explicitly created tables are there, rather than + # assertEqual -- the test db could have "extra" tables if 
there is a misconfigured + # template. (*cough* tsearch2 w/ the pg windows installer.) + self.assert_(not Set(metadata.tables) - Set(testing.db.table_names())) + metadata.drop_all(bind=testing.db) + +class SchemaManipulationTest(TestBase): + def test_append_constraint_unique(self): + meta = MetaData() + + users = Table('users', meta, Column('id', Integer)) + addresses = Table('addresses', meta, Column('id', Integer), Column('user_id', Integer)) + + fk = ForeignKeyConstraint(['user_id'],[users.c.id]) + + addresses.append_constraint(fk) + addresses.append_constraint(fk) + assert len(addresses.c.user_id.foreign_keys) == 1 + assert addresses.constraints == set([addresses.primary_key, fk]) + +class UnicodeReflectionTest(TestBase): + + def test_basic(self): + try: + # the 'convert_unicode' should not get in the way of the reflection + # process. reflecttable for oracle, postgres (others?) expect non-unicode + # strings in result sets/bind params + bind = engines.utf8_engine(options={'convert_unicode':True}) + metadata = MetaData(bind) + + if testing.against('sybase', 'maxdb', 'oracle'): + names = set(['plain']) + else: + names = set([u'plain', u'Unit\u00e9ble', u'\u6e2c\u8a66']) + + for name in names: + Table(name, metadata, Column('id', Integer, Sequence(name + "_id_seq"), primary_key=True)) + metadata.create_all() + + reflected = set(bind.table_names()) + if not names.issubset(reflected): + # Python source files in the utf-8 coding seem to normalize + # literals as NFC (and the above are explicitly NFC). Maybe + # this database normalizes NFD on reflection. + nfc = set([unicodedata.normalize('NFC', n) for n in names]) + self.assert_(nfc == names) + # Yep. But still ensure that bulk reflection and create/drop + # work with either normalization. + + r = MetaData(bind, reflect=True) + r.drop_all() + r.create_all() + finally: + metadata.drop_all() + bind.dispose() + + +class SchemaTest(TestBase): + + def test_iteration(self): metadata = MetaData() - table1 = Table('table1', metadata, + table1 = Table('table1', metadata, Column('col1', Integer, primary_key=True), schema='someschema') - table2 = Table('table2', metadata, + table2 = Table('table2', metadata, Column('col1', Integer, primary_key=True), Column('col2', Integer, ForeignKey('someschema.table1.col1')), schema='someschema') @@ -613,40 +653,68 @@ class SchemaTest(PersistTest): buf = StringIO.StringIO() def foo(s, p=None): buf.write(s) - gen = create_engine(testbase.db.name + "://", strategy="mock", executor=foo) - gen = gen.dialect.schemagenerator(gen) + gen = create_engine(testing.db.name + "://", strategy="mock", executor=foo) + gen = gen.dialect.schemagenerator(gen.dialect, gen) gen.traverse(table1) gen.traverse(table2) buf = buf.getvalue() print buf - assert buf.index("CREATE TABLE someschema.table1") > -1 - assert buf.index("CREATE TABLE someschema.table2") > -1 - - @testing.supported('mysql','postgres') - def testcreate(self): - engine = testbase.db - schema = engine.dialect.get_default_schema_name(engine) - #engine.echo = True - - if testbase.db.name == 'mysql': - schema = testbase.db.url.database + if testing.db.dialect.preparer(testing.db.dialect).omit_schema: + assert buf.index("CREATE TABLE table1") > -1 + assert buf.index("CREATE TABLE table2") > -1 else: + assert buf.index("CREATE TABLE someschema.table1") > -1 + assert buf.index("CREATE TABLE someschema.table2") > -1 + + @testing.unsupported('sqlite', 'firebird') + # fixme: revisit these below. 
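# [editor's note -- hedged sketch, not part of this patch; comments only,
#  since it sits between decorators.  The tests in this class exercise the
#  "schema" keyword of Table(): the schema name prefixes the emitted DDL
#  and must also be spelled out in ForeignKey targets.  'someschema' is
#  just an example name, as in test_iteration above.]
#
#   meta = MetaData()
#   parent = Table('parent', meta,
#                  Column('id', Integer, primary_key=True),
#                  schema='someschema')
#   child = Table('child', meta,
#                 Column('id', Integer, primary_key=True),
#                 Column('pid', Integer, ForeignKey('someschema.parent.id')),
#                 schema='someschema')
#   # emits CREATE TABLE someschema.parent / someschema.child, unless the
#   # dialect's preparer omits schemas entirely (e.g. SQLite).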
+ @testing.fails_on('oracle', 'mssql', 'sybase', 'access') + def test_explicit_default_schema(self): + engine = testing.db + + if testing.against('mysql'): + schema = testing.db.url.database + elif testing.against('postgres'): schema = 'public' - metadata = MetaData(testbase.db) - table1 = Table('table1', metadata, - Column('col1', Integer, primary_key=True), - schema=schema) - table2 = Table('table2', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', Integer, ForeignKey('%s.table1.col1' % schema)), - schema=schema) - metadata.create_all() - metadata.create_all(checkfirst=True) - metadata.clear() - table1 = Table('table1', metadata, autoload=True, schema=schema) - table2 = Table('table2', metadata, autoload=True, schema=schema) - metadata.drop_all() - + else: + schema = engine.dialect.get_default_schema_name(engine.connect()) + + metadata = MetaData(engine) + table1 = Table('table1', metadata, + Column('col1', Integer, primary_key=True), + schema=schema) + table2 = Table('table2', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', Integer, + ForeignKey('%s.table1.col1' % schema)), + schema=schema) + try: + metadata.create_all() + metadata.create_all(checkfirst=True) + metadata.clear() + + table1 = Table('table1', metadata, autoload=True, schema=schema) + table2 = Table('table2', metadata, autoload=True, schema=schema) + finally: + metadata.drop_all() + + +class HasSequenceTest(TestBase): + def setUpAll(self): + global metadata, users + metadata = MetaData() + users = Table('users', metadata, + Column('user_id', Integer, Sequence('user_id_seq'), primary_key=True), + Column('user_name', String(40)), + ) + + @testing.unsupported('sqlite', 'mysql', 'mssql', 'access', 'sybase') + def test_hassequence(self): + metadata.create_all(bind=testing.db) + self.assertEqual(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True) + metadata.drop_all(bind=testing.db) + self.assertEqual(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False) + + if __name__ == "__main__": - testbase.main() - + testenv.main() diff --git a/test/engine/transaction.py b/test/engine/transaction.py index 593a069a96..edae14da29 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() import sys, time, threading from sqlalchemy import * @@ -6,7 +6,7 @@ from sqlalchemy.orm import * from testlib import * -class TransactionTest(PersistTest): +class TransactionTest(TestBase): def setUpAll(self): global users, metadata metadata = MetaData() @@ -15,15 +15,15 @@ class TransactionTest(PersistTest): Column('user_name', VARCHAR(20)), test_needs_acid=True, ) - users.create(testbase.db) - + users.create(testing.db) + def tearDown(self): - testbase.db.connect().execute(users.delete()) + testing.db.connect().execute(users.delete()) def tearDownAll(self): - users.drop(testbase.db) - + users.drop(testing.db) + def testcommits(self): - connection = testbase.db.connect() + connection = testing.db.connect() transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name='user1') transaction.commit() @@ -37,23 +37,23 @@ class TransactionTest(PersistTest): result = connection.execute("select * from query_users") assert len(result.fetchall()) == 3 transaction.commit() - + def testrollback(self): """test a basic rollback""" - connection = testbase.db.connect() + connection = testing.db.connect() transaction = connection.begin() connection.execute(users.insert(), user_id=1, 
user_name='user1') connection.execute(users.insert(), user_id=2, user_name='user2') connection.execute(users.insert(), user_id=3, user_name='user3') transaction.rollback() - + result = connection.execute("select * from query_users") assert len(result.fetchall()) == 0 connection.close() def testraise(self): - connection = testbase.db.connect() - + connection = testing.db.connect() + transaction = connection.begin() try: connection.execute(users.insert(), user_id=1, user_name='user1') @@ -64,14 +64,15 @@ class TransactionTest(PersistTest): except Exception , e: print "Exception: ", e transaction.rollback() - + result = connection.execute("select * from query_users") assert len(result.fetchall()) == 0 connection.close() - + + @testing.exclude('mysql', '<', (5, 0, 3)) def testnestedrollback(self): - connection = testbase.db.connect() - + connection = testing.db.connect() + try: transaction = connection.begin() try: @@ -96,10 +97,11 @@ class TransactionTest(PersistTest): assert str(e) == 'uh oh' # and not "This transaction is inactive" finally: connection.close() - + + @testing.exclude('mysql', '<', (5, 0, 3)) def testnesting(self): - connection = testbase.db.connect() + connection = testing.db.connect() transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name='user1') connection.execute(users.insert(), user_id=2, user_name='user2') @@ -114,10 +116,54 @@ class TransactionTest(PersistTest): result = connection.execute("select * from query_users") assert len(result.fetchall()) == 0 connection.close() - - @testing.unsupported('sqlite') + + @testing.exclude('mysql', '<', (5, 0, 3)) + def testclose(self): + connection = testing.db.connect() + transaction = connection.begin() + connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=3, user_name='user3') + trans2 = connection.begin() + connection.execute(users.insert(), user_id=4, user_name='user4') + connection.execute(users.insert(), user_id=5, user_name='user5') + assert connection.in_transaction() + trans2.close() + assert connection.in_transaction() + transaction.commit() + assert not connection.in_transaction() + self.assert_(connection.scalar("select count(1) from query_users") == 5) + + result = connection.execute("select * from query_users") + assert len(result.fetchall()) == 5 + connection.close() + + @testing.exclude('mysql', '<', (5, 0, 3)) + def testclose2(self): + connection = testing.db.connect() + transaction = connection.begin() + connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=3, user_name='user3') + trans2 = connection.begin() + connection.execute(users.insert(), user_id=4, user_name='user4') + connection.execute(users.insert(), user_id=5, user_name='user5') + assert connection.in_transaction() + trans2.close() + assert connection.in_transaction() + transaction.close() + assert not connection.in_transaction() + self.assert_(connection.scalar("select count(1) from query_users") == 0) + + result = connection.execute("select * from query_users") + assert len(result.fetchall()) == 0 + connection.close() + + + @testing.unsupported('sqlite', 'mssql', 'sybase', 'access') + @testing.exclude('mysql', '<', (5, 0, 3)) def testnestedsubtransactionrollback(self): - connection = testbase.db.connect() + connection = testing.db.connect() transaction = connection.begin() 
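# [editor's note -- illustrative sketch, not part of this patch.  The
#  pattern under test: begin_nested() opens a SAVEPOINT, so rolling the
#  inner transaction back discards only the work done since that
#  savepoint, while the enclosing transaction can still commit.]
#
#   outer = connection.begin()
#   connection.execute(users.insert(), user_id=10, user_name='kept')
#   inner = connection.begin_nested()         # SAVEPOINT
#   connection.execute(users.insert(), user_id=11, user_name='discarded')
#   inner.rollback()                           # back to the savepoint
#   outer.commit()                             # row 10 persists, row 11 does not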
connection.execute(users.insert(), user_id=1, user_name='user1') trans2 = connection.begin_nested() @@ -125,16 +171,17 @@ class TransactionTest(PersistTest): trans2.rollback() connection.execute(users.insert(), user_id=3, user_name='user3') transaction.commit() - + self.assertEquals( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(3,)] ) connection.close() - @testing.unsupported('sqlite') + @testing.unsupported('sqlite', 'mssql', 'sybase', 'access') + @testing.exclude('mysql', '<', (5, 0, 3)) def testnestedsubtransactioncommit(self): - connection = testbase.db.connect() + connection = testing.db.connect() transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name='user1') trans2 = connection.begin_nested() @@ -142,16 +189,17 @@ class TransactionTest(PersistTest): trans2.commit() connection.execute(users.insert(), user_id=3, user_name='user3') transaction.commit() - + self.assertEquals( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(2,),(3,)] ) connection.close() - @testing.unsupported('sqlite') + @testing.unsupported('sqlite', 'mssql', 'sybase', 'access') + @testing.exclude('mysql', '<', (5, 0, 3)) def testrollbacktosubtransaction(self): - connection = testbase.db.connect() + connection = testing.db.connect() transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name='user1') trans2 = connection.begin_nested() @@ -161,98 +209,105 @@ class TransactionTest(PersistTest): trans3.rollback() connection.execute(users.insert(), user_id=4, user_name='user4') transaction.commit() - + self.assertEquals( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(4,)] ) connection.close() - - @testing.supported('postgres', 'mysql') + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) def testtwophasetransaction(self): - connection = testbase.db.connect() - + connection = testing.db.connect() + transaction = connection.begin_twophase() connection.execute(users.insert(), user_id=1, user_name='user1') transaction.prepare() transaction.commit() - + transaction = connection.begin_twophase() connection.execute(users.insert(), user_id=2, user_name='user2') transaction.commit() - + transaction = connection.begin_twophase() connection.execute(users.insert(), user_id=3, user_name='user3') transaction.rollback() - + transaction = connection.begin_twophase() connection.execute(users.insert(), user_id=4, user_name='user4') transaction.prepare() transaction.rollback() - + self.assertEquals( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(2,)] ) connection.close() - @testing.supported('postgres', 'mysql') - def testmixedtransaction(self): - connection = testbase.db.connect() - + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) + def testmixedtwophasetransaction(self): + connection = testing.db.connect() + transaction = connection.begin_twophase() connection.execute(users.insert(), user_id=1, user_name='user1') - + transaction2 = connection.begin() connection.execute(users.insert(), user_id=2, user_name='user2') - + transaction3 = connection.begin_nested() connection.execute(users.insert(), user_id=3, user_name='user3') - + transaction4 = connection.begin() connection.execute(users.insert(), user_id=4, user_name='user4') 
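# [editor's note -- sketch only, not part of this patch.  The outer
#  transaction in this test is a two-phase (XA) transaction; stripped of
#  the interleaved nested/plain transactions, the basic protocol is
#  begin_twophase(), do work, prepare(), then commit() or rollback().]
#
#   xa = connection.begin_twophase()
#   connection.execute(users.insert(), user_id=20, user_name='xa_user')
#   xa.prepare()    # phase one: the transaction is prepared
#   xa.commit()     # phase two: the prepared transaction is committed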
transaction4.commit() - + transaction3.rollback() - + connection.execute(users.insert(), user_id=5, user_name='user5') - + transaction2.commit() - + transaction.prepare() - + transaction.commit() - + self.assertEquals( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(2,),(5,)] ) connection.close() - - @testing.supported('postgres') + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + # fixme: see if this is still true and/or can be convert to fails_on() + @testing.unsupported('mysql') def testtwophaserecover(self): # MySQL recovery doesn't currently seem to work correctly # Prepared transactions disappear when connections are closed and even # when they aren't it doesn't seem possible to use the recovery id. - connection = testbase.db.connect() - + connection = testing.db.connect() + transaction = connection.begin_twophase() connection.execute(users.insert(), user_id=1, user_name='user1') transaction.prepare() - + connection.close() - connection2 = testbase.db.connect() - + connection2 = testing.db.connect() + self.assertEquals( connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [] ) - + recoverables = connection2.recover_twophase() self.assertTrue( transaction.xid in recoverables ) - + connection2.commit_prepared(transaction.xid, recover=True) self.assertEquals( @@ -261,19 +316,49 @@ class TransactionTest(PersistTest): ) connection2.close() -class AutoRollbackTest(PersistTest): + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) + def testmultipletwophase(self): + conn = testing.db.connect() + + xa = conn.begin_twophase() + conn.execute(users.insert(), user_id=1, user_name='user1') + xa.prepare() + xa.commit() + + xa = conn.begin_twophase() + conn.execute(users.insert(), user_id=2, user_name='user2') + xa.prepare() + xa.rollback() + + xa = conn.begin_twophase() + conn.execute(users.insert(), user_id=3, user_name='user3') + xa.rollback() + + xa = conn.begin_twophase() + conn.execute(users.insert(), user_id=4, user_name='user4') + xa.prepare() + xa.commit() + + result = conn.execute(select([users.c.user_name]).order_by(users.c.user_id)) + self.assertEqual(result.fetchall(), [('user1',),('user4',)]) + + conn.close() + +class AutoRollbackTest(TestBase): def setUpAll(self): global metadata metadata = MetaData() - + def tearDownAll(self): - metadata.drop_all(testbase.db) - + metadata.drop_all(testing.db) + @testing.unsupported('sqlite') def testrollback_deadlock(self): """test that returning connections to the pool clears any object locks.""" - conn1 = testbase.db.connect() - conn2 = testbase.db.connect() + conn1 = testing.db.connect() + conn2 = testing.db.connect() users = Table('deadlock_users', metadata, Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), @@ -282,19 +367,96 @@ class AutoRollbackTest(PersistTest): users.create(conn1) conn1.execute("select * from deadlock_users") conn1.close() - # without auto-rollback in the connection pool's return() logic, this deadlocks in Postgres, - # because conn1 is returned to the pool but still has a lock on "deadlock_users" + + # without auto-rollback in the connection pool's return() logic, this + # deadlocks in Postgres, because conn1 is returned to the pool but + # still has a lock on "deadlock_users". # comment out the rollback in pool/ConnectionFairy._close() to see ! 
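# [editor's note -- illustrative recap, not part of this patch.  conn1 was
#  closed above, which returns its DBAPI connection to the pool; the pool
#  issues a rollback() at that point, releasing whatever lock the earlier
#  SELECT left behind, so the DROP below is not blocked.  Condensed:]
#
#   conn1 = testing.db.connect()
#   conn1.execute("select * from deadlock_users")   # may leave a lock open
#   conn1.close()                                   # returned to pool, rolled back
#   testing.db.connect().execute("drop table deadlock_users")   # proceeds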
users.drop(conn2) conn2.close() -class TLTransactionTest(PersistTest): +class ExplicitAutoCommitTest(TestBase): + """test the 'autocommit' flag on select() and text() objects. + + Requires Postgres so that we may define a custom function which modifies the database. + """ + + __only_on__ = 'postgres' + + def setUpAll(self): + global metadata, foo + metadata = MetaData(testing.db) + foo = Table('foo', metadata, Column('id', Integer, primary_key=True), Column('data', String(100))) + metadata.create_all() + testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql") + + def tearDown(self): + foo.delete().execute() + + def tearDownAll(self): + testing.db.execute("drop function insert_foo(varchar)") + metadata.drop_all() + + def test_control(self): + # test that not using autocommit does not commit + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + conn1.execute(select([func.insert_foo('data1')])) + assert conn2.execute(select([foo.c.data])).fetchall() == [] + + conn1.execute(text("select insert_foo('moredata')")) + assert conn2.execute(select([foo.c.data])).fetchall() == [] + + trans = conn1.begin() + trans.commit() + + assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',), ('moredata',)] + + conn1.close() + conn2.close() + + def test_explicit_compiled(self): + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + conn1.execute(select([func.insert_foo('data1')], autocommit=True)) + assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',)] + + conn1.execute(select([func.insert_foo('data2')]).autocommit()) + assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',), ('data2',)] + + conn1.close() + conn2.close() + + def test_explicit_text(self): + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + conn1.execute(text("select insert_foo('moredata')", autocommit=True)) + assert conn2.execute(select([foo.c.data])).fetchall() == [('moredata',)] + + conn1.close() + conn2.close() + + def test_implicit_text(self): + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + conn1.execute(text("insert into foo (data) values ('implicitdata')")) + assert conn2.execute(select([foo.c.data])).fetchall() == [('implicitdata',)] + + conn1.close() + conn2.close() + + +class TLTransactionTest(TestBase): def setUpAll(self): global users, metadata, tlengine - tlengine = create_engine(testbase.db.url, strategy='threadlocal') + tlengine = create_engine(testing.db.url, strategy='threadlocal') metadata = MetaData() users = Table('query_users', metadata, - Column('user_id', INT, primary_key = True), + Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True), Column('user_name', VARCHAR(20)), test_needs_acid=True, ) @@ -304,7 +466,42 @@ class TLTransactionTest(PersistTest): def tearDownAll(self): users.drop(tlengine) tlengine.dispose() - + + def test_connection_close(self): + """test that when connections are closed for real, transactions are rolled back and disposed.""" + + c = tlengine.contextual_connect() + c.begin() + assert tlengine.session.in_transaction() + assert hasattr(tlengine.session, '_TLSession__transaction') + assert hasattr(tlengine.session, '_TLSession__trans') + c.close() + assert not tlengine.session.in_transaction() + assert not hasattr(tlengine.session, '_TLSession__transaction') + assert not hasattr(tlengine.session, '_TLSession__trans') + + def test_transaction_close(self): + c = tlengine.contextual_connect() + t = 
c.begin() + tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.execute(users.insert(), user_id=2, user_name='user2') + t2 = c.begin() + tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.execute(users.insert(), user_id=4, user_name='user4') + t2.close() + + result = c.execute("select * from query_users") + assert len(result.fetchall()) == 4 + + t.close() + + external_connection = tlengine.connect() + result = external_connection.execute("select * from query_users") + try: + assert len(result.fetchall()) == 0 + finally: + external_connection.close() + def testrollback(self): """test a basic rollback""" tlengine.begin() @@ -336,6 +533,8 @@ class TLTransactionTest(PersistTest): external_connection.close() def testcommits(self): + assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0 + connection = tlengine.contextual_connect() transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name='user1') @@ -348,7 +547,8 @@ class TLTransactionTest(PersistTest): transaction = connection.begin() result = connection.execute("select * from query_users") - assert len(result.fetchall()) == 3 + l = result.fetchall() + assert len(l) == 3, "expected 3 got %d" % len(l) transaction.commit() def testrollback_off_conn(self): @@ -400,10 +600,11 @@ class TLTransactionTest(PersistTest): assert len(result.fetchall()) == 3 finally: external_connection.close() - + @testing.unsupported('sqlite') + @testing.exclude('mysql', '<', (5, 0, 3)) def testnesting(self): - """tests nesting of tranacstions""" + """tests nesting of transactions""" external_connection = tlengine.connect() self.assert_(external_connection.connection is not tlengine.contextual_connect().connection) tlengine.begin() @@ -420,8 +621,9 @@ class TLTransactionTest(PersistTest): finally: external_connection.close() + @testing.exclude('mysql', '<', (5, 0, 3)) def testmixednesting(self): - """tests nesting of transactions off the TLEngine directly inside of + """tests nesting of transactions off the TLEngine directly inside of tranasctions off the connection from the TLEngine""" external_connection = tlengine.connect() self.assert_(external_connection.connection is not tlengine.contextual_connect().connection) @@ -448,6 +650,7 @@ class TLTransactionTest(PersistTest): finally: external_connection.close() + @testing.exclude('mysql', '<', (5, 0, 3)) def testmoremixednesting(self): """tests nesting of transactions off the connection from the TLEngine inside of tranasctions off thbe TLEngine directly.""" @@ -471,6 +674,7 @@ class TLTransactionTest(PersistTest): finally: external_connection.close() + @testing.exclude('mysql', '<', (5, 0, 3)) def testsessionnesting(self): class User(object): pass @@ -486,7 +690,7 @@ class TLTransactionTest(PersistTest): finally: clear_mappers() - + def testconnections(self): """tests that contextual_connect is threadlocal""" c1 = tlengine.contextual_connect() @@ -495,7 +699,34 @@ class TLTransactionTest(PersistTest): c2.close() assert c1.connection.connection is not None -class ForUpdateTest(PersistTest): + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) + def testtwophasetransaction(self): + tlengine.begin_twophase() + tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.prepare() + tlengine.commit() + + tlengine.begin_twophase() + tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.commit() + + 
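# [editor's note -- sketch, not part of this patch.  With the
#  strategy='threadlocal' engine used throughout this class, transaction
#  calls are made on the engine itself and apply to the current thread's
#  implicit connection, so no Connection object needs to be passed around.
#  The SQLite URL below is only an example.]
#
#   tl = create_engine('sqlite://', strategy='threadlocal')
#   tl.begin()
#   tl.execute(users.insert(), user_id=30, user_name='user30')
#   tl.rollback()    # discards the thread-local transaction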
tlengine.begin_twophase() + tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.rollback() + + tlengine.begin_twophase() + tlengine.execute(users.insert(), user_id=4, user_name='user4') + tlengine.prepare() + tlengine.rollback() + + self.assertEquals( + tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(2,)] + ) + +class ForUpdateTest(TestBase): def setUpAll(self): global counters, metadata metadata = MetaData() @@ -504,17 +735,17 @@ class ForUpdateTest(PersistTest): Column('counter_value', INT), test_needs_acid=True, ) - counters.create(testbase.db) + counters.create(testing.db) def tearDown(self): - testbase.db.connect().execute(counters.delete()) + testing.db.connect().execute(counters.delete()) def tearDownAll(self): - counters.drop(testbase.db) + counters.drop(testing.db) def increment(self, count, errors, update_style=True, delay=0.005): - con = testbase.db.connect() + con = testing.db.connect() sel = counters.select(for_update=update_style, whereclause=counters.c.counter_id==1) - + for i in xrange(count): trans = con.begin() try: @@ -535,10 +766,10 @@ class ForUpdateTest(PersistTest): trans.rollback() errors.append(e) break - con.close() - @testing.supported('mysql', 'oracle', 'postgres') + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access') + def testqueued_update(self): """Test SELECT FOR UPDATE with concurrent modifications. @@ -546,7 +777,7 @@ class ForUpdateTest(PersistTest): with each mutator trying to increment a value stored in user_name. """ - db = testbase.db + db = testing.db db.execute(counters.insert(), counter_id=1, counter_value=0) iterations, thread_count = 10, 5 @@ -572,8 +803,8 @@ class ForUpdateTest(PersistTest): def overlap(self, ids, errors, update_style): sel = counters.select(for_update=update_style, - whereclause=counters.c.counter_id.in_(*ids)) - con = testbase.db.connect() + whereclause=counters.c.counter_id.in_(ids)) + con = testing.db.connect() trans = con.begin() try: rows = con.execute(sel).fetchall() @@ -582,9 +813,10 @@ class ForUpdateTest(PersistTest): except Exception, e: trans.rollback() errors.append(e) + con.close() def _threaded_overlap(self, thread_count, groups, update_style=True, pool=5): - db = testbase.db + db = testing.db for cid in range(pool - 1): db.execute(counters.insert(), counter_id=cid + 1, counter_value=0) @@ -598,8 +830,8 @@ class ForUpdateTest(PersistTest): thread.join() return errors - - @testing.supported('mysql', 'oracle', 'postgres') + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access') def testqueued_select(self): """Simple SELECT FOR UPDATE conflict test""" @@ -608,13 +840,15 @@ class ForUpdateTest(PersistTest): sys.stderr.write("Failure: %s\n" % e) self.assert_(len(errors) == 0) - @testing.supported('oracle', 'postgres') + @testing.unsupported('sqlite', 'mysql', 'mssql', 'firebird', + 'sybase', 'access') def testnowait_select(self): """Simple SELECT FOR UPDATE NOWAIT conflict test""" errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)], update_style='nowait') self.assert_(len(errors) != 0) - + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py index e28c72cd73..fa112c3b34 100644 --- a/test/ext/activemapper.py +++ b/test/ext/activemapper.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() from datetime import datetime from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, 
many_to_many, objectstore @@ -10,25 +10,25 @@ import sqlalchemy from testlib import * -class testcase(PersistTest): +class testcase(TestBase): def setUpAll(self): clear_mappers() objectstore.clear() global Person, Preferences, Address - + class Person(ActiveMapper): class mapping: __version_id_col__ = 'row_version' - full_name = column(String) - first_name = column(String) - middle_name = column(String) - last_name = column(String) + full_name = column(String(128)) + first_name = column(String(128)) + middle_name = column(String(128)) + last_name = column(String(128)) birth_date = column(DateTime) - ssn = column(String) - gender = column(String) - home_phone = column(String) - cell_phone = column(String) - work_phone = column(String) + ssn = column(String(128)) + gender = column(String(128)) + home_phone = column(String(128)) + cell_phone = column(String(128)) + work_phone = column(String(128)) row_version = column(Integer, default=0) prefs_id = column(Integer, foreign_key=ForeignKey('preferences.id')) addresses = one_to_many('Address', colname='person_id', backref='person', order_by=['state', 'city', 'postal_code']) @@ -49,34 +49,34 @@ class testcase(PersistTest): class Preferences(ActiveMapper): class mapping: __table__ = 'preferences' - favorite_color = column(String) - personality_type = column(String) + favorite_color = column(String(128)) + personality_type = column(String(128)) class Address(ActiveMapper): class mapping: - # note that in other objects, the 'id' primary key is + # note that in other objects, the 'id' primary key is # automatically added -- if you specify a primary key, # then ActiveMapper will not add an integer primary key # for you. id = column(Integer, primary_key=True) - type = column(String) - address_1 = column(String) - city = column(String) - state = column(String) - postal_code = column(String) + type = column(String(128)) + address_1 = column(String(128)) + city = column(String(128)) + state = column(String(128)) + postal_code = column(String(128)) person_id = column(Integer, foreign_key=ForeignKey('person.id')) - activemapper.metadata.connect(testbase.db) + activemapper.metadata.bind = testing.db activemapper.create_tables() - + def tearDownAll(self): clear_mappers() activemapper.drop_tables() - + def tearDown(self): for t in activemapper.metadata.table_iterator(reverse=True): t.delete().execute() - + def create_person_one(self): # create a person p1 = Person( @@ -102,8 +102,8 @@ class testcase(PersistTest): ] ) return p1 - - + + def create_person_two(self): p2 = Person( full_name='Lacey LaCour', @@ -126,19 +126,19 @@ class testcase(PersistTest): # a "self.preferences = Preferences()" into the __init__ # of Person also doens't seem to fix this p2.preferences = Preferences() - + return p2 - - + + def test_create(self): p1 = self.create_person_one() objectstore.flush() objectstore.clear() - - results = Person.query.select() - + + results = Person.query.all() + self.assertEquals(len(results), 1) - + person = results[0] self.assertEquals(person.id, p1.id) self.assertEquals(len(person.addresses), 2) @@ -149,72 +149,72 @@ class testcase(PersistTest): p1 = self.create_person_one() objectstore.flush() objectstore.clear() - - person = Person.query.select()[0] + + person = Person.query.first() person.gender = 'F' objectstore.flush() objectstore.clear() self.assertEquals(person.row_version, 2) - person = Person.query.select()[0] + person = Person.query.first() person.gender = 'M' objectstore.flush() objectstore.clear() self.assertEquals(person.row_version, 3) #TODO: 
check that a concurrent modification raises exception - p1 = Person.query.select()[0] - s1 = objectstore.session + p1 = Person.query.first() + s1 = objectstore() s2 = create_session() - objectstore.context.current = s2 - p2 = Person.query.select()[0] + objectstore.registry.set(s2) + p2 = Person.query.first() p1.first_name = "jack" p2.first_name = "ed" objectstore.flush() try: - objectstore.context.current = s1 + objectstore.registry.set(s1) objectstore.flush() # Only dialects with a sane rowcount can detect the ConcurrentModificationError - if testbase.db.dialect.supports_sane_rowcount(): + if testing.db.dialect.supports_sane_rowcount: assert False except exceptions.ConcurrentModificationError: pass - - + + def test_delete(self): p1 = self.create_person_one() - + objectstore.flush() objectstore.clear() - - results = Person.query.select() + + results = Person.query.all() self.assertEquals(len(results), 1) - - results[0].delete() + + objectstore.delete(results[0]) objectstore.flush() objectstore.clear() - - results = Person.query.select() + + results = Person.query.all() self.assertEquals(len(results), 0) - - + + def test_multiple(self): p1 = self.create_person_one() p2 = self.create_person_two() - + objectstore.flush() objectstore.clear() - + # select and make sure we get back two results - people = Person.query.select() + people = Person.query.all() self.assertEquals(len(people), 2) - + # make sure that our backwards relationships work self.assertEquals(people[0].addresses[0].person.id, p1.id) self.assertEquals(people[1].addresses[0].person.id, p2.id) - + # try a more complex select - results = Person.query.select( + results = Person.query.filter( or_( and_( Address.c.person_id == Person.c.id, @@ -225,10 +225,10 @@ class testcase(PersistTest): Preferences.c.favorite_color == 'Green' ) ) - ) + ).all() self.assertEquals(len(results), 2) - - + + def test_oneway_backref(self): # FIXME: I don't know why, but it seems that my backwards relationship # on preferences still ends up being a list even though I pass @@ -237,32 +237,33 @@ class testcase(PersistTest): # uses a function which I dont think existed when you first wrote ActiveMapper. p1 = self.create_person_one() self.assertEquals(p1.preferences.person, p1) - p1.delete() - + objectstore.flush() + objectstore.delete(p1) + objectstore.flush() objectstore.clear() - - + + def test_select_by(self): # FIXME: either I don't understand select_by, or it doesn't work. # FIXED (as good as we can for now): yup....everyone thinks it works that way....it only # generates joins for keyword arguments, not ColumnClause args. would need a new layer of # "MapperClause" objects to use properties in expressions. 
(MB) - + p1 = self.create_person_one() p2 = self.create_person_two() - + objectstore.flush() objectstore.clear() - - results = Person.query.join('addresses').select( - Address.c.postal_code.like('30075') - ) + + results = Person.query.join('addresses').filter( + Address.c.postal_code.like('30075') + ).all() self.assertEquals(len(results), 1) self.assertEquals(Person.query.count(), 2) -class testmanytomany(PersistTest): +class testmanytomany(TestBase): def setUpAll(self): clear_mappers() objectstore.clear() @@ -282,7 +283,7 @@ class testmanytomany(PersistTest): name = column(String(30)) foorel = many_to_many("foo", secondarytable, backref='bazrel') - activemapper.metadata.connect(testbase.db) + activemapper.metadata.bind = testing.db activemapper.create_tables() # Create a couple of activemapper objects @@ -300,9 +301,9 @@ class testmanytomany(PersistTest): objectstore.flush() objectstore.clear() - foo1 = foo.query.get_by(name='foo1') - baz1 = baz.query.get_by(name='baz1') - + foo1 = foo.query.filter_by(name='foo1').one() + baz1 = baz.query.filter_by(name='baz1').one() + # Just checking ... assert (foo1.name == 'foo1') assert (baz1.name == 'baz1') @@ -316,8 +317,8 @@ class testmanytomany(PersistTest): # baz1 to foo1.bazrel - (AttributeError: 'foo' object has no attribute 'bazrel') foo1.bazrel.append(baz1) assert (foo1.bazrel == [baz1]) - -class testselfreferential(PersistTest): + +class testselfreferential(TestBase): def setUpAll(self): clear_mappers() objectstore.clear() @@ -328,8 +329,8 @@ class testselfreferential(PersistTest): name = column(String(30)) parent_id = column(Integer, foreign_key=ForeignKey('treenode.id')) children = one_to_many('TreeNode', colname='id', backref='parent') - - activemapper.metadata.connect(testbase.db) + + activemapper.metadata.bind = testing.db activemapper.create_tables() def tearDownAll(self): clear_mappers() @@ -341,16 +342,16 @@ class testselfreferential(PersistTest): t.children.append(TreeNode(name='node3')) objectstore.flush() objectstore.clear() - - t = TreeNode.query.get_by(name='node1') + + t = TreeNode.query.filter_by(name='node1').one() assert (t.name == 'node1') assert (t.children[0].name == 'node2') assert (t.children[1].name == 'node3') assert (t.children[1].parent is t) objectstore.clear() - t = TreeNode.query.get_by(name='node3') - assert (t.parent is TreeNode.query.get_by(name='node1')) - + t = TreeNode.query.filter_by(name='node3').one() + assert (t.parent is TreeNode.query.filter_by(name='node1').one()) + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/ext/alltests.py b/test/ext/alltests.py index 589f0f68f2..d5db4d01ed 100644 --- a/test/ext/alltests.py +++ b/test/ext/alltests.py @@ -1,12 +1,17 @@ -import testbase -import unittest, doctest +import testenv; testenv.configure_for_tests() +import doctest, sys, unittest def suite(): unittest_modules = ['ext.activemapper', 'ext.assignmapper', + 'ext.declarative', 'ext.orderinglist', 'ext.associationproxy'] - doctest_modules = ['sqlalchemy.ext.sqlsoup'] + + if sys.version_info >= (2, 4): + doctest_modules = ['sqlalchemy.ext.sqlsoup'] + else: + doctest_modules = [] alltests = unittest.TestSuite() for name in unittest_modules: @@ -20,4 +25,4 @@ def suite(): if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/ext/assignmapper.py b/test/ext/assignmapper.py index 31b3dd576f..1cb2ca3751 100644 --- a/test/ext/assignmapper.py +++ b/test/ext/assignmapper.py @@ -1,34 +1,35 @@ -import testbase - +import testenv; 
testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper from sqlalchemy.ext.assignmapper import assign_mapper from sqlalchemy.ext.sessioncontext import SessionContext from testlib import * -class AssignMapperTest(PersistTest): +class AssignMapperTest(TestBase): def setUpAll(self): global metadata, table, table2 - metadata = MetaData(testbase.db) - table = Table('sometable', metadata, + metadata = MetaData(testing.db) + table = Table('sometable', metadata, Column('id', Integer, primary_key=True), Column('data', String(30))) - table2 = Table('someothertable', metadata, + table2 = Table('someothertable', metadata, Column('id', Integer, primary_key=True), Column('someid', None, ForeignKey('sometable.id')) ) metadata.create_all() + @testing.uses_deprecated('SessionContext', 'assign_mapper') def setUp(self): global SomeObject, SomeOtherObject, ctx class SomeObject(object):pass class SomeOtherObject(object):pass - + ctx = SessionContext(create_session) assign_mapper(ctx, SomeObject, table, properties={ 'options':relation(SomeOtherObject) - }) + }) assign_mapper(ctx, SomeOtherObject, table2) s = SomeObject() @@ -41,20 +42,22 @@ class AssignMapperTest(PersistTest): def tearDownAll(self): metadata.drop_all() + def tearDown(self): for table in metadata.table_iterator(reverse=True): table.delete().execute() clear_mappers() - + + @testing.uses_deprecated('assign_mapper') def test_override_attributes(self): - + sso = SomeOtherObject.query().first() - + assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id s2 = SomeObject(someid=12) s3 = SomeOtherObject(someid=123, bogus=345) - + class ValidatedOtherObject(object):pass assign_mapper(ctx, ValidatedOtherObject, table2, validate=True) @@ -64,16 +67,17 @@ class AssignMapperTest(PersistTest): assert False except exceptions.ArgumentError: pass - + + @testing.uses_deprecated('assign_mapper') def test_dont_clobber_methods(self): class MyClass(object): def expunge(self): return "an expunge !" - + assign_mapper(ctx, MyClass, table2) - + assert MyClass().expunge() == "an expunge !" 
- + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py index f602871c2c..8837b4d04c 100644 --- a/test/ext/associationproxy.py +++ b/test/ext/associationproxy.py @@ -1,5 +1,5 @@ -import testbase - +import testenv; testenv.configure_for_tests() +import gc from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.collections import collection @@ -33,25 +33,25 @@ class ObjectCollection(object): def __iter__(self): return iter(self.values) -class _CollectionOperations(PersistTest): +class _CollectionOperations(TestBase): def setUp(self): collection_class = self.collection_class - metadata = MetaData(testbase.db) - + metadata = MetaData(testing.db) + parents_table = Table('Parent', metadata, Column('id', Integer, primary_key=True), - Column('name', String)) + Column('name', String(128))) children_table = Table('Children', metadata, Column('id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('Parent.id')), - Column('foo', String), - Column('name', String)) + Column('foo', String(128)), + Column('name', String(128))) class Parent(object): children = association_proxy('_children', 'name') - + def __init__(self, name): self.name = name @@ -79,7 +79,8 @@ class _CollectionOperations(PersistTest): self.metadata.drop_all() def roundtrip(self, obj): - self.session.save(obj) + if obj not in self.session: + self.session.save(obj) self.session.flush() id, type_ = obj.id, type(obj) self.session.clear() @@ -87,7 +88,7 @@ class _CollectionOperations(PersistTest): def _test_sequence_ops(self): Parent, Child = self.Parent, self.Child - + p1 = Parent('P1') self.assert_(not p1._children) @@ -113,7 +114,7 @@ class _CollectionOperations(PersistTest): self.assert_(p1._children[0].name == 'regular') self.assert_(p1._children[1].name == 'proxied') - + del p1._children[1] self.assert_(len(p1._children) == 1) @@ -124,7 +125,7 @@ class _CollectionOperations(PersistTest): self.assert_(len(p1._children) == 0) self.assert_(len(p1.children) == 0) - + p1.children = ['a','b','c'] self.assert_(len(p1._children) == 3) self.assert_(len(p1.children) == 3) @@ -151,7 +152,7 @@ class _CollectionOperations(PersistTest): p1.children.append('changed-in-place') self.assert_(p1.children.count('changed-in-place') == 2) - + p1.children.remove('changed-in-place') self.assert_(p1.children.count('changed-in-place') == 1) @@ -185,7 +186,53 @@ class _CollectionOperations(PersistTest): after = ['a', 'b', 'O', 'z', 'O', 'z', 'O', 'h', 'O', 'j'] self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) - + + self.assertRaises(TypeError, set, [p1.children]) + + p1.children *= 0 + after = [] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children += ['a', 'b'] + after = ['a', 'b'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children += ['c'] + after = ['a', 'b', 'c'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children *= 1 + after = ['a', 'b', 'c'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children *= 2 + after = ['a', 'b', 'c', 'a', 'b', 'c'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children = ['a'] + after = ['a'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + 
self.assert_((p1.children * 2) == ['a', 'a']) + self.assert_((2 * p1.children) == ['a', 'a']) + self.assert_((p1.children * 0) == []) + self.assert_((0 * p1.children) == []) + + self.assert_((p1.children + ['b']) == ['a', 'b']) + self.assert_((['b'] + p1.children) == ['b', 'a']) + + try: + p1.children + 123 + assert False + except TypeError: + assert True + class DefaultTest(_CollectionOperations): def __init__(self, *args, **kw): super(DefaultTest, self).__init__(*args, **kw) @@ -194,6 +241,7 @@ class DefaultTest(_CollectionOperations): def test_sequence_ops(self): self._test_sequence_ops() + class ListTest(_CollectionOperations): def __init__(self, *args, **kw): super(ListTest, self).__init__(*args, **kw) @@ -247,7 +295,7 @@ class CustomDictTest(DictTest): self.assert_(p1._children['a'].name == 'regular') self.assert_(p1._children['b'].name == 'proxied') - + del p1._children['b'] self.assert_(len(p1._children) == 1) @@ -270,7 +318,7 @@ class CustomDictTest(DictTest): self.assert_(len(p1._children) == 3) self.assert_(len(p1.children) == 3) - p1.children['e'] = 'changed-in-place' + p1.children['e'] = 'changed-in-place' self.assert_(p1.children['e'] == 'changed-in-place') inplace_id = p1._children['e'].id p1 = self.roundtrip(p1) @@ -279,19 +327,22 @@ class CustomDictTest(DictTest): p1._children = {} self.assert_(len(p1.children) == 0) - + try: p1._children = [] self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(True) try: p1._children = None self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(True) + self.assertRaises(TypeError, set, [p1.children]) + + class SetTest(_CollectionOperations): def __init__(self, *args, **kw): super(SetTest, self).__init__(*args, **kw) @@ -342,7 +393,7 @@ class SetTest(_CollectionOperations): self.assert_(len(p1._children) == 0) self.assert_(len(p1.children) == 0) - + p1.children = ['a','b','c'] self.assert_(len(p1._children) == 3) self.assert_(len(p1.children) == 3) @@ -377,11 +428,11 @@ class SetTest(_CollectionOperations): p1 = self.roundtrip(p1) self.assert_(len(p1.children) == 2) self.assert_(popped not in p1.children) - + p1.children = ['a','b','c'] p1 = self.roundtrip(p1) self.assert_(p1.children == set(['a','b','c'])) - + p1.children.discard('b') p1 = self.roundtrip(p1) self.assert_(p1.children == set(['a', 'c'])) @@ -396,15 +447,17 @@ class SetTest(_CollectionOperations): try: p1._children = [] self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(True) try: p1._children = None self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(True) + self.assertRaises(TypeError, set, [p1.children]) + def test_set_comparisons(self): Parent, Child = self.Parent, self.Child @@ -432,7 +485,7 @@ class SetTest(_CollectionOperations): control.issubset(other)) self.assertEqual(p1.children.issuperset(other), control.issuperset(other)) - + self.assert_((p1.children == other) == (control == other)) self.assert_((p1.children != other) == (control != other)) self.assert_((p1.children < other) == (control < other)) @@ -475,6 +528,39 @@ class SetTest(_CollectionOperations): print 'got', repr(p.children) raise + # in-place mutations + for op in ('|=', '-=', '&=', '^='): + for base in (['a', 'b', 'c'], []): + for other in (set(['a','b','c']), set(['a','b','c','d']), + set(['a']), set(['a','b']), + set(['c','d']), set(['e', 'f', 'g']), + frozenset(['e', 'f', 'g']), + set()): + p = Parent('p') + p.children = base[:] + control = set(base[:]) + + 
exec "p.children %s other" % op + exec "control %s other" % op + + try: + self.assert_(p.children == control) + except: + print 'Test %s %s %s:' % (set(base), op, other) + print 'want', repr(control) + print 'got', repr(p.children) + raise + + p = self.roundtrip(p) + + try: + self.assert_(p.children == control) + except: + print 'Test %s %s %s:' % (base, op, other) + print 'want', repr(control) + print 'got', repr(p.children) + raise + class CustomSetTest(SetTest): def __init__(self, *args, **kw): @@ -506,20 +592,20 @@ class CustomObjectTest(_CollectionOperations): except TypeError: pass -class ScalarTest(PersistTest): +class ScalarTest(TestBase): def test_scalar_proxy(self): - metadata = MetaData(testbase.db) - + metadata = MetaData(testing.db) + parents_table = Table('Parent', metadata, Column('id', Integer, primary_key=True), - Column('name', String)) + Column('name', String(128))) children_table = Table('Children', metadata, Column('id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('Parent.id')), - Column('foo', String), - Column('bar', String), - Column('baz', String)) + Column('foo', String(128)), + Column('bar', String(128)), + Column('baz', String(128))) class Parent(object): foo = association_proxy('child', 'foo') @@ -527,7 +613,7 @@ class ScalarTest(PersistTest): creator=lambda v: Child(bar=v)) baz = association_proxy('child', 'baz', creator=lambda v: Child(baz=v)) - + def __init__(self, name): self.name = name @@ -545,12 +631,13 @@ class ScalarTest(PersistTest): session = create_session() def roundtrip(obj): - session.save(obj) + if obj not in session: + session.save(obj) session.flush() id, type_ = obj.id, type(obj) session.clear() return session.query(type_).get(id) - + p = Parent('p') # No child @@ -570,7 +657,7 @@ class ScalarTest(PersistTest): self.assert_(p.foo == 'a') self.assert_(p.bar == 'x') self.assert_(p.baz == 'c') - + p = roundtrip(p) self.assert_(p.foo == 'a') @@ -620,25 +707,25 @@ class ScalarTest(PersistTest): # Ensure an immediate __set__ works. 
p2 = Parent('p2') p2.bar = 'quux' - -class LazyLoadTest(PersistTest): + +class LazyLoadTest(TestBase): def setUp(self): - metadata = MetaData(testbase.db) - + metadata = MetaData(testing.db) + parents_table = Table('Parent', metadata, Column('id', Integer, primary_key=True), - Column('name', String)) + Column('name', String(128))) children_table = Table('Children', metadata, Column('id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('Parent.id')), - Column('foo', String), - Column('name', String)) + Column('foo', String(128)), + Column('name', String(128))) class Parent(object): children = association_proxy('_children', 'name') - + def __init__(self, name): self.name = name @@ -727,7 +814,66 @@ class LazyLoadTest(PersistTest): self.assert_('_children' in p.__dict__) self.assert_(p._children is not None) - + + +class ReconstitutionTest(TestBase): + def setUp(self): + metadata = MetaData(testing.db) + parents = Table('parents', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30))) + children = Table('children', metadata, + Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('parents.id')), + Column('name', String(30))) + metadata.create_all() + parents.insert().execute(name='p1') + + class Parent(object): + kids = association_proxy('children', 'name') + def __init__(self, name): + self.name = name + + class Child(object): + def __init__(self, name): + self.name = name + + mapper(Parent, parents, properties=dict(children=relation(Child))) + mapper(Child, children) + + self.metadata = metadata + self.Parent = Parent + + def tearDown(self): + self.metadata.drop_all() + + def test_weak_identity_map(self): + session = create_session(weak_identity_map=True) + + def add_child(parent_name, child_name): + parent = (session.query(self.Parent). 
+ filter_by(name=parent_name)).one() + parent.kids.append(child_name) + + + add_child('p1', 'c1') + gc.collect() + add_child('p1', 'c2') + + session.flush() + p = session.query(self.Parent).filter_by(name='p1').one() + assert set(p.kids) == set(['c1', 'c2']), p.kids + + def test_copy(self): + import copy + p = self.Parent('p1') + p.kids.extend(['c1', 'c2']) + p_copy = copy.copy(p) + del p + gc.collect() + + assert set(p_copy.kids) == set(['c1', 'c2']), p.kids + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/ext/declarative.py b/test/ext/declarative.py new file mode 100644 index 0000000000..ab07627dda --- /dev/null +++ b/test/ext/declarative.py @@ -0,0 +1,797 @@ +import testenv; testenv.configure_for_tests() + +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.orm.interfaces import MapperExtension +from sqlalchemy.ext.declarative import declarative_base, declared_synonym, \ + synonym_for, comparable_using +from sqlalchemy import exceptions +from testlib.fixtures import Base as Fixture +from testlib import * + + +class DeclarativeTest(TestBase, AssertsExecutionResults): + def setUp(self): + global Base + Base = declarative_base(testing.db) + + def tearDown(self): + Base.metadata.drop_all() + + def test_basic(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user") + + class Address(Base, Fixture): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + email = Column(String(50), key='_email') + user_id = Column('user_id', Integer, ForeignKey('users.id'), + key='_user_id') + + Base.metadata.create_all() + + assert Address.__table__.c['id'].name == 'id' + assert Address.__table__.c['_email'].name == 'email' + assert Address.__table__.c['_user_id'].name == 'user_id' + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email=='two').one() + self.assertEquals(a1, Address(email='two')) + self.assertEquals(a1.user, User(name='u1')) + + def test_recompile_on_othermapper(self): + """declarative version of the same test in mappers.py""" + + from sqlalchemy.orm import mapperlib + + class User(Base): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + + class Address(Base): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + user = relation("User", primaryjoin=user_id==User.id, backref="addresses") + + assert mapperlib._new_mappers is True + u = User() + assert User.addresses + assert mapperlib._new_mappers is False + + def test_nice_dependency_error(self): + class User(Base): + __tablename__ = 'users' + id = Column('id', Integer, primary_key=True) + addresses = relation("Address") + + def go(): + class Address(Base): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + foo = column_property(User.id==5) + self.assertRaises(exceptions.InvalidRequestError, go) + + def test_add_prop(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + User.name = 
Column('name', String(50)) + User.addresses = relation("Address", backref="user") + + class Address(Base, Fixture): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + Address.email = Column(String(50), key='_email') + Address.user_id = Column('user_id', Integer, ForeignKey('users.id'), + key='_user_id') + + Base.metadata.create_all() + + assert Address.__table__.c['id'].name == 'id' + assert Address.__table__.c['_email'].name == 'email' + assert Address.__table__.c['_user_id'].name == 'user_id' + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email=='two').one() + self.assertEquals(a1, Address(email='two')) + self.assertEquals(a1.user, User(name='u1')) + + + def test_custom_mapper(self): + class MyExt(MapperExtension): + def create_instance(self): + return "CHECK" + + def mymapper(cls, tbl, **kwargs): + kwargs['extension'] = MyExt() + return mapper(cls, tbl, **kwargs) + + from sqlalchemy.orm.mapper import Mapper + class MyMapper(Mapper): + def __init__(self, *args, **kwargs): + kwargs['extension'] = MyExt() + Mapper.__init__(self, *args, **kwargs) + + from sqlalchemy.orm import scoping + ss = scoping.ScopedSession(create_session) + ss.extension = MyExt() + ss_mapper = ss.mapper + + for mapperfunc in (mymapper, MyMapper, ss_mapper): + base = declarative_base() + class Foo(base): + __tablename__ = 'foo' + __mapper_cls__ = mapperfunc + id = Column(Integer, primary_key=True) + assert Foo.__mapper__.compile().extension.create_instance() == 'CHECK' + + base = declarative_base(mapper=mapperfunc) + class Foo(base): + __tablename__ = 'foo' + id = Column(Integer, primary_key=True) + assert Foo.__mapper__.compile().extension.create_instance() == 'CHECK' + + + @testing.emits_warning('Ignoring declarative-like tuple value of ' + 'attribute id') + def test_oops(self): + def define(): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True), + name = Column('name', String(50)) + assert False + self.assertRaisesMessage( + exceptions.ArgumentError, + "Mapper Mapper|User|users could not assemble any primary key", + define) + + def test_expression(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user") + + class Address(Base, Fixture): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + + User.address_count = column_property(select([func.count(Address.id)]).where(Address.user_id==User.id).as_scalar()) + + Base.metadata.create_all() + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', address_count=2, addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + def test_column(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + + User.a = Column('a', String(10)) + User.b = Column(String(10)) + + 
Base.metadata.create_all() + + u1 = User(name='u1', a='a', b='b') + assert u1.a == 'a' + assert User.a.get_history(u1) == (['a'], [], []) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), + [User(name='u1', a='a', b='b')]) + + def test_column_properties(self): + + class Address(Base, Fixture): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + adr_count = column_property(select([func.count(Address.id)], Address.user_id==id).as_scalar()) + addresses = relation(Address) + + Base.metadata.create_all() + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', adr_count=2, addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + def test_column_properties_2(self): + + class Address(Base, Fixture): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + # this is not "valid" but we want to test that Address.id doesnt get stuck into user's table + adr_count = Address.id + + self.assertEquals(set(User.__table__.c.keys()), set(['id', 'name'])) + self.assertEquals(set(Address.__table__.c.keys()), set(['id', 'email', 'user_id'])) + + def test_deferred(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + name = deferred(Column(String(50))) + + Base.metadata.create_all() + sess = create_session() + sess.save(User(name='u1')) + sess.flush() + sess.clear() + + u1 = sess.query(User).filter(User.name=='u1').one() + assert 'name' not in u1.__dict__ + def go(): + assert u1.name == 'u1' + self.assert_sql_count(testing.db, go, 1) + + def test_synonym_inline(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + _name = Column('name', String(50)) + def _set_name(self, name): + self._name = "SOMENAME " + name + def _get_name(self): + return self._name + name = synonym('_name', descriptor=property(_get_name, _set_name)) + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + assert u1.name == "SOMENAME someuser", u1.name + sess.save(u1) + sess.flush() + self.assertEquals(sess.query(User).filter(User.name=="SOMENAME someuser").one(), u1) + + @testing.uses_deprecated('Call to deprecated function declared_synonym') + def test_decl_synonym_inline(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + _name = Column('name', String(50)) + def _set_name(self, name): + self._name = "SOMENAME " + name + def _get_name(self): + return self._name + name = declared_synonym(property(_get_name, _set_name), '_name') + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + assert u1.name == "SOMENAME someuser", u1.name + sess.save(u1) + sess.flush() + self.assertEquals(sess.query(User).filter(User.name=="SOMENAME someuser").one(), u1) + + def test_synonym_added(self): + class 
User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + _name = Column('name', String(50)) + def _set_name(self, name): + self._name = "SOMENAME " + name + def _get_name(self): + return self._name + name = property(_get_name, _set_name) + User.name = synonym('_name', descriptor=User.name) + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + assert u1.name == "SOMENAME someuser", u1.name + sess.save(u1) + sess.flush() + self.assertEquals(sess.query(User).filter(User.name=="SOMENAME someuser").one(), u1) + + @testing.uses_deprecated('Call to deprecated function declared_synonym') + def test_decl_synonym_added(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + _name = Column('name', String(50)) + def _set_name(self, name): + self._name = "SOMENAME " + name + def _get_name(self): + return self._name + name = property(_get_name, _set_name) + User.name = declared_synonym(User.name, '_name') + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + assert u1.name == "SOMENAME someuser", u1.name + sess.save(u1) + sess.flush() + self.assertEquals(sess.query(User).filter(User.name=="SOMENAME someuser").one(), u1) + + def test_joined_inheritance(self): + class Company(Base, Fixture): + __tablename__ = 'companies' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + employees = relation("Person") + + class Person(Base, Fixture): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + company_id = Column('company_id', Integer, ForeignKey('companies.id')) + name = Column('name', String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity':'engineer'} + id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column('primary_language', String(50)) + + class Manager(Person): + __tablename__ = 'managers' + __mapper_args__ = {'polymorphic_identity':'manager'} + id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) + golf_swing = Column('golf_swing', String(50)) + + Base.metadata.create_all() + + sess = create_session() + c1 = Company(name="MegaCorp, Inc.", employees=[ + Engineer(name="dilbert", primary_language="java"), + Engineer(name="wally", primary_language="c++"), + Manager(name="dogbert", golf_swing="fore!") + ]) + + c2 = Company(name="Elbonia, Inc.", employees=[ + Engineer(name="vlad", primary_language="cobol") + ]) + + sess.save(c1) + sess.save(c2) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).first(), c2) + + def test_inheritance_with_undefined_relation(self): + class Parent(Base): + __tablename__ = 'parent' + id = Column('id', Integer, primary_key=True) + tp = Column('type', String(50)) + __mapper_args__ = dict(polymorphic_on = tp) + + + class Child1(Parent): + __tablename__ = 'child1' + id = Column('id', Integer, ForeignKey('parent.id'), primary_key=True) + related_child2 = Column('c2', Integer, ForeignKey('child2.id')) + __mapper_args__ = dict(polymorphic_identity = 'child1') + + # no exception is raised by the ForeignKey to "child2" even though child2 doesn't exist yet + + class Child2(Parent): + __tablename__ = 'child2' + id = Column('id', Integer, 
ForeignKey('parent.id'), primary_key=True) + related_child1 = Column('c1', Integer) + __mapper_args__ = dict(polymorphic_identity = 'child2') + + compile_mappers() # no exceptions here + + def test_reentrant_compile_via_foreignkey(self): + + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user") + + class Address(Base, Fixture): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey(User.id)) + + compile_mappers() # this forces a re-entrant compile() due to the User.id within the ForeignKey + + Base.metadata.create_all() + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + def test_relation_reference(self): + class Address(Base, Fixture): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user", + primaryjoin=id==Address.user_id) + + User.address_count = column_property(select([func.count(Address.id)]).where(Address.user_id==User.id).as_scalar()) + + Base.metadata.create_all() + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', address_count=2, addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + def test_single_inheritance(self): + class Company(Base, Fixture): + __tablename__ = 'companies' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + employees = relation("Person") + + class Person(Base, Fixture): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + company_id = Column('company_id', Integer, ForeignKey('companies.id')) + name = Column('name', String(50)) + discriminator = Column('type', String(50)) + primary_language = Column('primary_language', String(50)) + golf_swing = Column('golf_swing', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity':'engineer'} + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity':'manager'} + + Base.metadata.create_all() + + sess = create_session() + c1 = Company(name="MegaCorp, Inc.", employees=[ + Engineer(name="dilbert", primary_language="java"), + Engineer(name="wally", primary_language="c++"), + Manager(name="dogbert", golf_swing="fore!") + ]) + + c2 = Company(name="Elbonia, Inc.", employees=[ + Engineer(name="vlad", primary_language="cobol") + ]) + + sess.save(c1) + sess.save(c2) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(Person).filter(Engineer.primary_language=='cobol').first(), Engineer(name='vlad')) + self.assertEquals(sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).first(), c2) + + def test_with_explicit_autoloaded(self): + meta = 
MetaData(testing.db) + t1 = Table('t1', meta, Column('id', String(50), primary_key=True), Column('data', String(50))) + meta.create_all() + try: + class MyObj(Base): + __table__ = Table('t1', Base.metadata, autoload=True) + + sess = create_session() + m = MyObj(id="someid", data="somedata") + sess.save(m) + sess.flush() + + assert t1.select().execute().fetchall() == [('someid', 'somedata')] + finally: + meta.drop_all() + + +class DeclarativeReflectionTest(TestBase): + def setUpAll(self): + global reflection_metadata + reflection_metadata = MetaData(testing.db) + + Table('users', reflection_metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + test_needs_fk=True) + Table('addresses', reflection_metadata, + Column('id', Integer, primary_key=True), + Column('email', String(50)), + Column('user_id', Integer, ForeignKey('users.id')), + test_needs_fk=True) + Table('imhandles', reflection_metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer), + Column('network', String(50)), + Column('handle', String(50)), + test_needs_fk=True) + + reflection_metadata.create_all() + + def setUp(self): + global Base + Base = declarative_base(testing.db) + + def tearDown(self): + for t in reflection_metadata.table_iterator(): + t.delete().execute() + + def tearDownAll(self): + reflection_metadata.drop_all() + + def test_basic(self): + meta = MetaData(testing.db) + + class User(Base, Fixture): + __tablename__ = 'users' + __autoload__ = True + addresses = relation("Address", backref="user") + + class Address(Base, Fixture): + __tablename__ = 'addresses' + __autoload__ = True + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email=='two').one() + self.assertEquals(a1, Address(email='two')) + self.assertEquals(a1.user, User(name='u1')) + + def test_rekey(self): + meta = MetaData(testing.db) + + class User(Base, Fixture): + __tablename__ = 'users' + __autoload__ = True + nom = Column('name', String(50), key='nom') + addresses = relation("Address", backref="user") + + class Address(Base, Fixture): + __tablename__ = 'addresses' + __autoload__ = True + + u1 = User(nom='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(nom='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email=='two').one() + self.assertEquals(a1, Address(email='two')) + self.assertEquals(a1.user, User(nom='u1')) + + self.assertRaises(TypeError, User, name='u3') + + def test_supplied_fk(self): + meta = MetaData(testing.db) + + class IMHandle(Base, Fixture): + __tablename__ = 'imhandles' + __autoload__ = True + + user_id = Column('user_id', Integer, + ForeignKey('users.id')) + class User(Base, Fixture): + __tablename__ = 'users' + __autoload__ = True + handles = relation("IMHandle", backref="user") + + u1 = User(name='u1', handles=[ + IMHandle(network='blabber', handle='foo'), + IMHandle(network='lol', handle='zomg') + ]) + sess = create_session() + sess.save(u1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).all(), [User(name='u1', handles=[ + IMHandle(network='blabber', handle='foo'), + 
IMHandle(network='lol', handle='zomg') + ])]) + + a1 = sess.query(IMHandle).filter(IMHandle.handle=='zomg').one() + self.assertEquals(a1, IMHandle(network='lol', handle='zomg')) + self.assertEquals(a1.user, User(name='u1')) + + def test_synonym_for(self): + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + + @synonym_for('name') + @property + def namesyn(self): + return self.name + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + assert u1.name == "someuser", u1.name + assert u1.namesyn == 'someuser', u1.namesyn + sess.save(u1) + sess.flush() + + rt = sess.query(User).filter(User.namesyn=='someuser').one() + self.assertEquals(rt, u1) + + def test_comparable_using(self): + class NameComparator(PropComparator): + @property + def upperself(self): + cls = self.prop.parent.class_ + col = getattr(cls, 'name') + return func.upper(col) + + def operate(self, op, other, **kw): + return op(self.upperself, other, **kw) + + class User(Base, Fixture): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + + @comparable_using(NameComparator) + @property + def uc_name(self): + return self.name is not None and self.name.upper() or None + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + assert u1.name == "someuser", u1.name + assert u1.uc_name == 'SOMEUSER', u1.uc_name + sess.save(u1) + sess.flush() + sess.clear() + + rt = sess.query(User).filter(User.uc_name=='SOMEUSER').one() + self.assertEquals(rt, u1) + sess.clear() + + rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one() + self.assertEquals(rt, u1) + +if __name__ == '__main__': + testing.main() diff --git a/test/ext/orderinglist.py b/test/ext/orderinglist.py index d16e20da73..ff27a63753 100644 --- a/test/ext/orderinglist.py +++ b/test/ext/orderinglist.py @@ -1,13 +1,13 @@ -import testbase - +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.orderinglist import * from testlib import * + metadata = None -# order in whole steps +# order in whole steps def step_numbering(step): def f(index, collection): return step * index @@ -37,7 +37,7 @@ def alpha_ordering(index, collection): s += chr(index + 65) return s -class OrderingListTest(PersistTest): +class OrderingListTest(TestBase): def setUp(self): global metadata, slides_table, bullets_table, Slide, Bullet slides_table, bullets_table = None, None @@ -51,16 +51,16 @@ class OrderingListTest(PersistTest): global metadata, slides_table, bullets_table, Slide, Bullet - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) slides_table = Table('test_Slides', metadata, Column('id', Integer, primary_key=True), - Column('name', String)) + Column('name', String(128))) bullets_table = Table('test_Bullets', metadata, Column('id', Integer, primary_key=True), Column('slide_id', Integer, ForeignKey('test_Slides.id')), Column('position', Integer), - Column('text', String)) + Column('text', String(128))) class Slide(object): def __init__(self, name): @@ -140,7 +140,7 @@ class OrderingListTest(PersistTest): self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 4) - + titles = ['s1/b1','s1/b2','s1/b100','s1/b4'] found = [b.text for b in srt.bullets] @@ -174,7 +174,7 @@ class OrderingListTest(PersistTest): self.assert_(s1.bullets[0].position == 1) self.assert_(s1.bullets[1].position == 2) 
self.assert_(s1.bullets[2].position == 3) - + s1.bullets.append(Bullet('s1/b4')) self.assert_(s1.bullets[0].position == 1) self.assert_(s1.bullets[1].position == 2) @@ -204,7 +204,7 @@ class OrderingListTest(PersistTest): found = [b.text for b in srt.bullets] self.assert_(titles == found) - + def test_insert(self): self._setup(ordering_list('position')) @@ -218,7 +218,7 @@ class OrderingListTest(PersistTest): self.assert_(s1.bullets[1].position == 1) self.assert_(s1.bullets[2].position == 2) self.assert_(s1.bullets[3].position == 3) - + s1.bullets.insert(2, Bullet('insert_at_2')) self.assert_(s1.bullets[0].position == 0) self.assert_(s1.bullets[1].position == 1) @@ -247,7 +247,7 @@ class OrderingListTest(PersistTest): self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 6) - + texts = ['1','2','insert_at_2','3','4','999'] found = [b.text for b in srt.bullets] @@ -290,7 +290,7 @@ class OrderingListTest(PersistTest): self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 3) - + texts = ['1', '6', '3'] for i, text in enumerate(texts): self.assert_(srt.bullets[i].position == i) @@ -325,13 +325,13 @@ class OrderingListTest(PersistTest): session.clear() srt = session.query(Slide).get(id) - + self.assert_(srt.bullets) self.assert_(len(srt.bullets) == 3) self.assert_(srt.bullets[1].text == 'new 2') self.assert_(srt.bullets[2].text == '3') - + def test_funky_ordering(self): class Pos(object): def __init__(self): @@ -365,7 +365,7 @@ class OrderingListTest(PersistTest): fibbed.insert(2, Pos()) fibbed.insert(4, Pos()) fibbed.insert(6, Pos()) - + for li, pos in (0,1), (1,2), (2,3), (3,5), (4,8), (5,13), (6,21), (7,34): self.assert_(fibbed[li].position == pos) @@ -381,5 +381,6 @@ class OrderingListTest(PersistTest): for li, pos in (0,'A'), (1,'B'), (2,'C'), (3,'D'): self.assert_(alpha[li].position == pos) + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/alltests.py b/test/orm/alltests.py index 4f8f4b6b7a..73406c00d5 100644 --- a/test/orm/alltests.py +++ b/test/orm/alltests.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest import inheritance.alltests as inheritance @@ -11,26 +11,31 @@ def suite(): 'orm.lazy_relations', 'orm.eager_relations', 'orm.mapper', + 'orm.expire', + 'orm.selectable', 'orm.collection', 'orm.generative', 'orm.lazytest1', 'orm.assorted_eager', - - 'orm.sessioncontext', + + 'orm.naturalpks', + 'orm.sessioncontext', 'orm.unitofwork', 'orm.session', 'orm.cascade', 'orm.relationships', 'orm.association', 'orm.merge', + 'orm.pickled', 'orm.memusage', - + 'orm.cycles', 'orm.entity', 'orm.compile', 'orm.manytomany', 'orm.onetoone', + 'orm.dynamic', ) alltests = unittest.TestSuite() for name in modules_to_test: @@ -44,4 +49,4 @@ def suite(): if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/orm/association.py b/test/orm/association.py index a2b8994188..65d7025383 100644 --- a/test/orm/association.py +++ b/test/orm/association.py @@ -1,14 +1,15 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * -class AssociationTest(PersistTest): +class AssociationTest(TestBase): + @testing.uses_deprecated('association option') def setUpAll(self): global items, item_keywords, keywords, metadata, Item, Keyword, KeywordAssociation - metadata = MetaData(testbase.db) - items = Table('items', metadata, + metadata = MetaData(testing.db) + items = Table('items', metadata, 
Column('item_id', Integer, primary_key=True), Column('name', String(40)), ) @@ -22,7 +23,7 @@ class AssociationTest(PersistTest): Column('name', String(40)) ) metadata.create_all() - + class Item(object): def __init__(self, name): self.name = name @@ -39,7 +40,7 @@ class AssociationTest(PersistTest): self.data = data def __repr__(self): return "KeywordAssociation itemid=%d keyword=%s data=%s" % (self.item_id, repr(self.keyword), self.data) - + mapper(Keyword, keywords) mapper(KeywordAssociation, item_keywords, properties={ 'keyword':relation(Keyword, lazy=False) @@ -47,14 +48,14 @@ class AssociationTest(PersistTest): mapper(Item, items, properties={ 'keywords' : relation(KeywordAssociation, association=Keyword) }) - + def tearDown(self): for t in metadata.table_iterator(reverse=True): t.delete().execute() def tearDownAll(self): clear_mappers() metadata.drop_all() - + def testinsert(self): sess = create_session() item1 = Item('item1') @@ -67,7 +68,7 @@ class AssociationTest(PersistTest): sess.flush() saved = repr([item1, item2]) sess.clear() - l = sess.query(Item).select() + l = sess.query(Item).all() loaded = repr(l) print saved print loaded @@ -80,14 +81,14 @@ class AssociationTest(PersistTest): item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) sess.save(item1) sess.flush() - + red_keyword = item1.keywords[1].keyword del item1.keywords[1] item1.keywords.append(KeywordAssociation(red_keyword, 'new_red_assoc')) sess.flush() saved = repr([item1]) sess.clear() - l = sess.query(Item).select() + l = sess.query(Item).all() loaded = repr(l) print saved print loaded @@ -103,7 +104,7 @@ class AssociationTest(PersistTest): sess.save(item1) sess.save(item2) sess.flush() - + red_keyword = item1.keywords[1].keyword del item1.keywords[0] del item1.keywords[0] @@ -112,16 +113,17 @@ class AssociationTest(PersistTest): item2.keywords.append(KeywordAssociation(purple_keyword, 'purple_item2_assoc')) item1.keywords.append(KeywordAssociation(purple_keyword, 'purple_item1_assoc')) item1.keywords.append(KeywordAssociation(Keyword('yellow'), 'yellow_assoc')) - + sess.flush() saved = repr([item1, item2]) sess.clear() - l = sess.query(Item).select() + l = sess.query(Item).all() loaded = repr(l) print saved print loaded self.assert_(saved == loaded) + @testing.uses_deprecated('association option') def testdelete(self): sess = create_session() item1 = Item('item1') @@ -139,10 +141,10 @@ class AssociationTest(PersistTest): sess.flush() self.assert_(item_keywords.count().scalar() == 0) -class AssociationTest2(PersistTest): +class AssociationTest2(TestBase): def setUpAll(self): global table_originals, table_people, table_isauthor, metadata, Originals, People, IsAuthor - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) table_originals = Table('Originals', metadata, Column('ID', Integer, primary_key=True), Column('Title', String(200), nullable=False), @@ -154,9 +156,9 @@ class AssociationTest2(PersistTest): Column('Country', CHAR(2), default='es'), ) table_isauthor = Table('IsAuthor', metadata, - Column('OriginalsID', Integer, ForeignKey('Originals.ID'), + Column('OriginalsID', Integer, ForeignKey('Originals.ID'), default=None), - Column('PeopleID', Integer, ForeignKey('People.ID'), + Column('PeopleID', Integer, ForeignKey('People.ID'), default=None), Column('Kind', CHAR(1), default='A'), ) @@ -167,7 +169,7 @@ default=None), for k,v in kw.iteritems(): setattr(self, k, v) def display(self): - c = [ "%s=%s" % (col.key, repr(getattr(self, col.key))) for col + c = [ "%s=%s" % (col.key, 
repr(getattr(self, col.key))) for col in self.c ] return "%s(%s)" % (self.__class__.__name__, ', '.join(c)) def __repr__(self): @@ -185,7 +187,7 @@ in self.c ] properties={ 'people': relation(IsAuthor, association=People), 'authors': relation(People, secondary=table_isauthor, backref='written', - primaryjoin=and_(table_originals.c.ID==table_isauthor.c.OriginalsID, + primaryjoin=and_(table_originals.c.ID==table_isauthor.c.OriginalsID, table_isauthor.c.Kind=='A')), 'title': table_originals.c.Title, 'date': table_originals.c.Date, @@ -195,9 +197,9 @@ in self.c ] 'name': table_people.c.Name, 'country': table_people.c.Country, }) - mapper(IsAuthor, table_isauthor, - primary_key=[table_isauthor.c.OriginalsID, table_isauthor.c.PeopleID, -table_isauthor.c.Kind], + mapper(IsAuthor, table_isauthor, + primary_key=[table_isauthor.c.OriginalsID, table_isauthor.c.PeopleID, +table_isauthor.c.Kind], properties={ 'original': relation(Originals, lazy=False), 'person': relation(People, lazy=False), @@ -219,6 +221,6 @@ table_isauthor.c.Kind], sess.flush() - + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py index 652186b8e6..af3fcbc7bb 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/assorted_eager.py @@ -1,45 +1,53 @@ """eager loading unittests derived from mailing list-reported problems and trac tickets.""" -import testbase +import testenv; testenv.configure_for_tests() import random, datetime from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext from testlib import * +from testlib import fixtures -class EagerTest(AssertMixin): +class EagerTest(TestBase, AssertsExecutionResults): def setUpAll(self): global dbmeta, owners, categories, tests, options, Owner, Category, Test, Option, false - dbmeta = MetaData(testbase.db) - + dbmeta = MetaData(testing.db) + # determine a literal value for "false" based on the dialect - false = Boolean().dialect_impl(testbase.db.dialect).convert_bind_param(False, testbase.db.dialect) - + # FIXME: this PassiveDefault setup is bogus. 
+ bp = Boolean().dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect) + if bp: + false = str(bp(False)) + elif testing.against('maxdb'): + false = text('FALSE') + else: + false = str(False) + owners = Table ( 'owners', dbmeta , - Column ( 'id', Integer, primary_key=True, nullable=False ), - Column('data', String(30)) ) + Column ( 'id', Integer, primary_key=True, nullable=False ), + Column('data', String(30)) ) categories=Table( 'categories', dbmeta, - Column ( 'id', Integer,primary_key=True, nullable=False ), - Column ( 'name', VARCHAR(20), index=True ) ) + Column ( 'id', Integer,primary_key=True, nullable=False ), + Column ( 'name', VARCHAR(20), index=True ) ) tests = Table ( 'tests', dbmeta , - Column ( 'id', Integer, primary_key=True, nullable=False ), - Column ( 'owner_id',Integer, ForeignKey('owners.id'), nullable=False,index=True ), - Column ( 'category_id', Integer, ForeignKey('categories.id'),nullable=False,index=True )) + Column ( 'id', Integer, primary_key=True, nullable=False ), + Column ( 'owner_id',Integer, ForeignKey('owners.id'), nullable=False,index=True ), + Column ( 'category_id', Integer, ForeignKey('categories.id'),nullable=False,index=True )) options = Table ( 'options', dbmeta , - Column ( 'test_id', Integer, ForeignKey ( 'tests.id' ), primary_key=True, nullable=False ), - Column ( 'owner_id', Integer, ForeignKey ( 'owners.id' ), primary_key=True, nullable=False ), - Column ( 'someoption', Boolean, PassiveDefault(str(false)), nullable=False ) ) + Column ( 'test_id', Integer, ForeignKey ( 'tests.id' ), primary_key=True, nullable=False ), + Column ( 'owner_id', Integer, ForeignKey ( 'owners.id' ), primary_key=True, nullable=False ), + Column ( 'someoption', Boolean, PassiveDefault(false), nullable=False ) ) dbmeta.create_all() class Owner(object): - pass + pass class Category(object): - pass + pass class Test(object): - pass + pass class Option(object): - pass + pass mapper(Owner,owners) mapper(Category,categories) mapper(Option,options,properties={'owner':relation(Owner),'test':relation(Test)}) @@ -47,8 +55,8 @@ class EagerTest(AssertMixin): 'owner':relation(Owner,backref='tests'), 'category':relation(Category), 'owner_option': relation(Option,primaryjoin=and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id), - foreignkey=[options.c.test_id, options.c.owner_id], - uselist=False ) + foreign_keys=[options.c.test_id, options.c.owner_id], + uselist=False ) }) s=create_session() @@ -63,17 +71,17 @@ class EagerTest(AssertMixin): s.save(c) for i in range(3): - t=Test() - t.owner=o - t.category=c - s.save(t) - if i==1: - op=Option() - op.someoption=True - t.owner_option=op - if i==2: - op=Option() - t.owner_option=op + t=Test() + t.owner=o + t.category=c + s.save(t) + if i==1: + op=Option() + op.someoption=True + t.owner_option=op + if i==2: + op=Option() + t.owner_option=op s.flush() s.close() @@ -84,7 +92,7 @@ class EagerTest(AssertMixin): def test_noorm(self): """test the control case""" - # I want to display a list of tests owned by owner 1 + # I want to display a list of tests owned by owner 1 # if someoption is false or he hasn't specified it yet (null) # but not if he set it to true (example someoption is for hiding) @@ -96,27 +104,61 @@ class EagerTest(AssertMixin): # not orm style correct query print "Obtaining correct results without orm" result = select( [tests.c.id,categories.c.name], - and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)), - order_by=[tests.c.id], - 
from_obj=[tests.join(categories).outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))] ).execute().fetchall() + and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)), + order_by=[tests.c.id], + from_obj=[tests.join(categories).outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))] ).execute().fetchall() print result assert result == [(1, u'Some Category'), (3, u'Some Category')] - + def test_withouteagerload(self): + s = create_session() + l = (s.query(Test). + select_from(tests.outerjoin(options, + and_(tests.c.id == options.c.test_id, + tests.c.owner_id == + options.c.owner_id))). + filter(and_(tests.c.owner_id==1, + or_(options.c.someoption==None, + options.c.someoption==False)))) + + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + print result + assert result == [u'1 Some Category', u'3 Some Category'] + + @testing.uses_deprecated('//select') + def test_withouteagerload_deprecated(self): s = create_session() l=s.query(Test).select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)), - from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))]) + from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))]) result = ["%d %s" % ( t.id,t.category.name ) for t in l] print result assert result == [u'1 Some Category', u'3 Some Category'] def test_witheagerload(self): - """test that an eagerload locates the correct "from" clause with + """test that an eagerload locates the correct "from" clause with which to attach to, when presented with a query that already has a complicated from clause.""" s = create_session() q=s.query(Test).options(eagerload('category')) + + l=(q.select_from(tests.outerjoin(options, + and_(tests.c.id == + options.c.test_id, + tests.c.owner_id == + options.c.owner_id))). 
+ filter(and_(tests.c.owner_id==1,or_(options.c.someoption==None, + options.c.someoption==False)))) + + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + print result + assert result == [u'1 Some Category', u'3 Some Category'] + + @testing.uses_deprecated('//select') + def test_witheagerload_deprecated(self): + """As test_witheagerload, but via select().""" + s = create_session() + q=s.query(Test).options(eagerload('category')) l=q.select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)), - from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))]) + from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))]) result = ["%d %s" % ( t.id,t.category.name ) for t in l] print result assert result == [u'1 Some Category', u'3 Some Category'] @@ -125,34 +167,58 @@ class EagerTest(AssertMixin): """test the same as witheagerload except using generative""" s = create_session() q=s.query(Test).options(eagerload('category')) - l=q.filter ( + l=q.filter ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)) ).outerjoin('owner_option') - + result = ["%d %s" % ( t.id,t.category.name ) for t in l] print result assert result == [u'1 Some Category', u'3 Some Category'] + @testing.unsupported('sybase') def test_withoutouterjoin_literal(self): + s = create_session() + q = s.query(Test).options(eagerload('category')) + l = (q.filter( + (tests.c.owner_id==1) & + ('options.someoption is null or options.someoption=%s' % false)). + join('owner_option')) + + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + print result + assert result == [u'3 Some Category'] + + @testing.unsupported('sybase') + @testing.uses_deprecated('//select', '//join_to') + def test_withoutouterjoin_literal_deprecated(self): s = create_session() q=s.query(Test).options(eagerload('category')) l=q.select( (tests.c.owner_id==1) & ('options.someoption is null or options.someoption=%s' % false) & q.join_to('owner_option') ) - result = ["%d %s" % ( t.id,t.category.name ) for t in l] + result = ["%d %s" % ( t.id,t.category.name ) for t in l] print result assert result == [u'3 Some Category'] def test_withoutouterjoin(self): + s = create_session() + q=s.query(Test).options(eagerload('category')) + l = q.filter( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) ).join('owner_option') + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + print result + assert result == [u'3 Some Category'] + + @testing.uses_deprecated('//select', '//join_to', '//join_via') + def test_withoutouterjoin_deprecated(self): s = create_session() q=s.query(Test).options(eagerload('category')) l=q.select( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) & q.join_to('owner_option') ) - result = ["%d %s" % ( t.id,t.category.name ) for t in l] + result = ["%d %s" % ( t.id,t.category.name ) for t in l] print result assert result == [u'3 Some Category'] -class EagerTest2(AssertMixin): +class EagerTest2(TestBase, AssertsExecutionResults): def setUpAll(self): global metadata, middle, left, right - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) middle = Table('middle', metadata, Column('id', Integer, primary_key = True), Column('data', String(50)), @@ -173,10 +239,11 @@ class EagerTest2(AssertMixin): def tearDown(self): for t in metadata.table_iterator(reverse=True): t.delete().execute() - + + 
@testing.fails_on('maxdb') def testeagerterminate(self): """test that eager query generation does not include the same mapper's table twice. - + or, that bi-directional eager loads dont include each other in eager query generation.""" class Middle(object): def __init__(self, data): self.data = data @@ -193,14 +260,14 @@ class EagerTest2(AssertMixin): 'right': relation(Right, lazy=False, backref=backref('middle', lazy=False)), } ) - session = create_session(bind=testbase.db) + session = create_session(bind=testing.db) p = Middle('test1') p.left.append(Left('tag1')) p.right.append(Right('tag2')) session.save(p) session.flush() session.clear() - obj = session.query(Left).get_by(tag='tag1') + obj = session.query(Left).filter_by(tag='tag1').one() print obj.middle.right[0] class EagerTest3(ORMTest): @@ -219,7 +286,8 @@ class EagerTest3(ORMTest): Column ( 'id', Integer, primary_key=True, nullable=False ), Column ( 'data_id', Integer, ForeignKey('datas.id')), Column ( 'somedata', Integer, nullable=False )) - + + @testing.fails_on('maxdb') def test_nesting_with_functions(self): class Data(object): pass class Foo(object):pass @@ -236,7 +304,7 @@ class EagerTest3(ORMTest): d.a=x s.save(d) data.append(d) - + for x in range(10): rid=random.randint(0,len(data) - 1) somedata=random.randint(1,50000) @@ -251,22 +319,23 @@ class EagerTest3(ORMTest): [stats.c.data_id,func.max(stats.c.somedata).label('max')], stats.c.data_id<=25, group_by=[stats.c.data_id]).alias('arb') - + arb_result = arb_data.execute().fetchall() # order the result list descending based on 'max' arb_result.sort(lambda a, b:cmp(b['max'],a['max'])) # extract just the "data_id" from it arb_result = [row['data_id'] for row in arb_result] - - # now query for Data objects using that above select, adding the + + # now query for Data objects using that above select, adding the # "order by max desc" separately - q=s.query(Data).options(eagerload('foo')).select( - from_obj=[datas.join(arb_data,arb_data.c.data_id==datas.c.id)], - order_by=[desc(arb_data.c.max)],limit=10) - + q=(s.query(Data).options(eagerload('foo')). + select_from(datas.join(arb_data,arb_data.c.data_id==datas.c.id)). + order_by(desc(arb_data.c.max)). 
+ limit(10)) + # extract "data_id" from the list of result objects verify_result = [d.id for d in q] - + # assert equality including ordering (may break if the DB "ORDER BY" and python's sort() used differing # algorithms and there are repeated 'somedata' values in the list) assert verify_result == arb_result @@ -278,12 +347,13 @@ class EagerTest4(ORMTest): Column('department_id', Integer, primary_key=True), Column('name', String(50))) - employees = Table('employees', metadata, + employees = Table('employees', metadata, Column('person_id', Integer, primary_key=True), Column('name', String(50)), Column('department_id', Integer, ForeignKey('departments.department_id'))) + @testing.fails_on('maxdb') def test_basic(self): class Department(object): def __init__(self, **kwargs): @@ -324,13 +394,13 @@ class EagerTest4(ORMTest): assert q[0] is d2 class EagerTest5(ORMTest): - """test the construction of AliasedClauses for the same eager load property but different + """test the construction of AliasedClauses for the same eager load property but different parent mappers, due to inheritance""" def define_tables(self, metadata): global base, derived, derivedII, comments base = Table( 'base', metadata, - Column('uid', String(30), primary_key=True), + Column('uid', String(30), primary_key=True), Column('x', String(30)) ) @@ -390,17 +460,17 @@ class EagerTest5(ORMTest): derivedMapper = mapper(Derived, derived, inherits=baseMapper) derivedIIMapper = mapper(DerivedII, derivedII, inherits=baseMapper) sess = create_session() - d = Derived(1, 'x', 'y') - d.comments = [Comment(1, 'comment')] - d2 = DerivedII(2, 'xx', 'z') - d2.comments = [Comment(2, 'comment')] + d = Derived('uid1', 'x', 'y') + d.comments = [Comment('uid1', 'comment')] + d2 = DerivedII('uid2', 'xx', 'z') + d2.comments = [Comment('uid2', 'comment')] sess.save(d) sess.save(d2) sess.flush() sess.clear() # this eager load sets up an AliasedClauses for the "comment" relationship, # then stores it in clauses_by_lead_mapper[mapper for Derived] - d = sess.query(Derived).get(1) + d = sess.query(Derived).get('uid1') sess.clear() assert len([c for c in d.comments]) == 1 @@ -408,7 +478,7 @@ class EagerTest5(ORMTest): # and should store it in clauses_by_lead_mapper[mapper for DerivedII]. # the bug was that the previous AliasedClause create prevented this population # from occurring. 
- d2 = sess.query(DerivedII).get(2) + d2 = sess.query(DerivedII).get('uid2') sess.clear() # object is not in the session; therefore the lazy load cant trigger here, # eager load had to succeed @@ -417,24 +487,24 @@ class EagerTest5(ORMTest): class EagerTest6(ORMTest): def define_tables(self, metadata): global designType, design, part, inheritedPart - designType = Table('design_types', metadata, - Column('design_type_id', Integer, primary_key=True), - ) + designType = Table('design_types', metadata, + Column('design_type_id', Integer, primary_key=True), + ) - design =Table('design', metadata, - Column('design_id', Integer, primary_key=True), - Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) + design =Table('design', metadata, + Column('design_id', Integer, primary_key=True), + Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) - part = Table('parts', metadata, - Column('part_id', Integer, primary_key=True), - Column('design_id', Integer, ForeignKey('design.design_id')), - Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) + part = Table('parts', metadata, + Column('part_id', Integer, primary_key=True), + Column('design_id', Integer, ForeignKey('design.design_id')), + Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) inheritedPart = Table('inherited_part', metadata, - Column('ip_id', Integer, primary_key=True), - Column('part_id', Integer, ForeignKey('parts.part_id')), - Column('design_id', Integer, ForeignKey('design.design_id')), - ) + Column('ip_id', Integer, primary_key=True), + Column('part_id', Integer, ForeignKey('parts.part_id')), + Column('design_id', Integer, ForeignKey('design.design_id')), + ) def testone(self): class Part(object):pass @@ -445,20 +515,23 @@ class EagerTest6(ORMTest): mapper(Part, part) mapper(InheritedPart, inheritedPart, properties=dict( - part=relation(Part, lazy=False) + part=relation(Part, lazy=False) )) mapper(Design, design, properties=dict( - parts=relation(Part, private=True, backref="design"), - inheritedParts=relation(InheritedPart, private=True, backref="design"), + inheritedParts=relation(InheritedPart, + cascade="all, delete-orphan", + backref="design"), )) mapper(DesignType, designType, properties=dict( - # designs=relation(Design, private=True, backref="type"), + # designs=relation(Design, private=True, backref="type"), )) class_mapper(Design).add_property("type", relation(DesignType, lazy=False, backref="designs")) - class_mapper(Part).add_property("design", relation(Design, lazy=False, backref="parts")) + + class_mapper(Part).add_property("design", relation(Design, lazy=False, backref=backref("parts", cascade="all, delete-orphan"))) + #Part.mapper.add_property("designType", relation(DesignType)) d = Design() @@ -470,6 +543,7 @@ class EagerTest6(ORMTest): x.inheritedParts class EagerTest7(ORMTest): + @testing.uses_deprecated('SessionContext') def define_tables(self, metadata): global companies_table, addresses_table, invoice_table, phones_table, items_table, ctx global Company, Address, Phone, Item,Invoice @@ -498,7 +572,7 @@ class EagerTest7(ORMTest): invoice_table = Table('invoices', metadata, Column('invoice_id', Integer, Sequence('invoice_id_seq', optional=True), primary_key = True), Column('company_id', Integer, ForeignKey("companies.company_id")), - Column('date', DateTime), + Column('date', DateTime), ) items_table = Table('items', metadata, @@ -532,8 +606,9 @@ class EagerTest7(ORMTest): def __repr__(self): return "Item: " + 
repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty) + @testing.uses_deprecated('SessionContext') def testone(self): - """tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated + """tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated the bug, which is that when the single Company is loaded, no further processing of the rows occurred in order to load the Company's second Address object.""" @@ -642,7 +717,7 @@ class EagerTest7(ORMTest): # set up an invoice i1 = Invoice() i1.date = datetime.datetime.now() - i1.company = c1 + i1.company = a item1 = Item() item1.code = 'aaaa' @@ -714,11 +789,12 @@ class EagerTest8(ORMTest): ) def setUp(self): - testbase.db.execute("INSERT INTO prj (title) values('project 1');") - testbase.db.execute("INSERT INTO task_status (id) values(1);") - testbase.db.execute("INSERT INTO task_type(id) values(1);") - testbase.db.execute("INSERT INTO task (title, task_type_id, status_id, prj_id) values('task 1',1,1,1);") + testing.db.execute(project_t.insert(), {'id':1}) + testing.db.execute(task_status_t.insert(), {'id':1}) + testing.db.execute(task_type_t.insert(), {'id':1}) + testing.db.execute(task_t.insert(), {'title':u'task 1', 'task_type_id':1, 'status_id':1, 'prj_id':1}) + @testing.fails_on('maxdb') def test_nested_joins(self): # this is testing some subtle column resolution stuff, # concerning corresponding_column() being extremely accurate @@ -731,7 +807,7 @@ class EagerTest8(ORMTest): tsk_cnt_join = outerjoin(project_t, task_t, task_t.c.prj_id==project_t.c.id) - ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')], + ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')], from_obj=[tsk_cnt_join], group_by=[project_t.c.id]).alias('prj_tsk_cnt_s') j = join(project_t, ss, project_t.c.id == ss.c.prj_id) @@ -743,12 +819,12 @@ class EagerTest8(ORMTest): mapper(Message_Type, message_type_t) - mapper(Message, message_t, + mapper(Message, message_t, properties=dict(type=relation(Message_Type, lazy=False, uselist=False), )) tsk_cnt_join = outerjoin(project_t, task_t, task_t.c.prj_id==project_t.c.id) - ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')], + ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')], from_obj=[tsk_cnt_join], group_by=[project_t.c.id]).alias('prj_tsk_cnt_s') j = join(project_t, ss, project_t.c.id == ss.c.prj_id) @@ -766,9 +842,87 @@ class EagerTest8(ORMTest): session = create_session() - for t in session.query(cls.mapper).limit(10).offset(0).list(): - print t.id, t.title, t.props_cnt - - -if __name__ == "__main__": - testbase.main() + for t in session.query(cls.mapper).limit(10).offset(0).all(): + print t.id, t.title, t.props_cnt + +class EagerTest9(ORMTest): + """test the usage of query options to eagerly load specific paths. + + this relies upon the 'path' construct used by PropertyOption to relate + LoaderStrategies to specific paths, as well as the path state maintained + throughout the query setup/mapper instances process. 
+ """ + + def define_tables(self, metadata): + global accounts_table, transactions_table, entries_table + accounts_table = Table('accounts', metadata, + Column('account_id', Integer, primary_key=True), + Column('name', String(40)), + ) + transactions_table = Table('transactions', metadata, + Column('transaction_id', Integer, primary_key=True), + Column('name', String(40)), + ) + entries_table = Table('entries', metadata, + Column('entry_id', Integer, primary_key=True), + Column('name', String(40)), + Column('account_id', Integer, ForeignKey(accounts_table.c.account_id)), + Column('transaction_id', Integer, ForeignKey(transactions_table.c.transaction_id)), + ) + + @testing.fails_on('maxdb') + def test_eagerload_on_path(self): + class Account(fixtures.Base): + pass + + class Transaction(fixtures.Base): + pass + + class Entry(fixtures.Base): + pass + + mapper(Account, accounts_table) + mapper(Transaction, transactions_table) + mapper(Entry, entries_table, properties = dict( + account = relation(Account, uselist=False, backref=backref('entries', lazy=True)), + transaction = relation(Transaction, uselist=False, backref=backref('entries', lazy=False)), + )) + + session = create_session() + + tx1 = Transaction(name='tx1') + tx2 = Transaction(name='tx2') + + acc1 = Account(name='acc1') + ent11 = Entry(name='ent11', account=acc1, transaction=tx1) + ent12 = Entry(name='ent12', account=acc1, transaction=tx2) + + acc2 = Account(name='acc2') + ent21 = Entry(name='ent21', account=acc2, transaction=tx1) + ent22 = Entry(name='ent22', account=acc2, transaction=tx2) + + session.save(acc1) + session.flush() + session.clear() + + def go(): + # load just the first Account. eager loading will actually load all objects saved thus far, + # but will not eagerly load the "accounts" off the immediate "entries"; only the + # "accounts" off the entries->transaction->entries + acc = session.query(Account).options(eagerload_all('entries.transaction.entries.account')).first() + + # no sql occurs + assert acc.name == 'acc1' + assert acc.entries[0].transaction.entries[0].account.name == 'acc1' + assert acc.entries[0].transaction.entries[1].account.name == 'acc2' + + # lazyload triggers but no sql occurs because many-to-one uses cached query.get() + for e in acc.entries: + assert e.account is acc + + self.assert_sql_count(testing.db, go, 1) + + + +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 9b5f738bf7..caa129e5ea 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -1,56 +1,54 @@ -import testbase +import testenv; testenv.configure_for_tests() import pickle import sqlalchemy.orm.attributes as attributes from sqlalchemy.orm.collections import collection from sqlalchemy import exceptions from testlib import * +from testlib import fixtures +ROLLBACK_SUPPORTED=False + +# these test classes defined at the module +# level to support pickling class MyTest(object):pass class MyTest2(object):pass - -class AttributesTest(PersistTest): - """tests for the attributes.py module, which deals with tracking attribute changes on an object.""" - def testbasic(self): + +class AttributesTest(TestBase): + + def test_basic(self): class User(object):pass - manager = attributes.AttributeManager() - manager.register_attribute(User, 'user_id', uselist = False) - manager.register_attribute(User, 'user_name', uselist = False) - manager.register_attribute(User, 'email_address', uselist = False) - + + attributes.register_class(User) + attributes.register_attribute(User, 
'user_id', uselist = False, useobject=False) + attributes.register_attribute(User, 'user_name', uselist = False, useobject=False) + attributes.register_attribute(User, 'email_address', uselist = False, useobject=False) + u = User() - print repr(u.__dict__) - u.user_id = 7 u.user_name = 'john' u.email_address = 'lala@123.com' - - print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - manager.commit(u) - print repr(u.__dict__) + u._state.commit_all() self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') u.user_name = 'heythere' u.email_address = 'foo@bar.com' - print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com') - - manager.rollback(u) - print repr(u.__dict__) - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - def testpickleness(self): - manager = attributes.AttributeManager() - manager.register_attribute(MyTest, 'user_id', uselist = False) - manager.register_attribute(MyTest, 'user_name', uselist = False) - manager.register_attribute(MyTest, 'email_address', uselist = False) - manager.register_attribute(MyTest2, 'a', uselist = False) - manager.register_attribute(MyTest2, 'b', uselist = False) + def test_pickleness(self): + attributes.register_class(MyTest) + attributes.register_class(MyTest2) + attributes.register_attribute(MyTest, 'user_id', uselist = False, useobject=False) + attributes.register_attribute(MyTest, 'user_name', uselist = False, useobject=False) + attributes.register_attribute(MyTest, 'email_address', uselist = False, useobject=False) + attributes.register_attribute(MyTest2, 'a', uselist = False, useobject=False) + attributes.register_attribute(MyTest2, 'b', uselist = False, useobject=False) # shouldnt be pickling callables at the class level def somecallable(*args): return None attr_name = 'mt2' - manager.register_attribute(MyTest, attr_name, uselist = True, trackparent=True, callable_=somecallable) + attributes.register_attribute(MyTest, attr_name, uselist = True, trackparent=True, callable_=somecallable, useobject=True) o = MyTest() o.mt2.append(MyTest2()) @@ -59,21 +57,21 @@ class AttributesTest(PersistTest): pk_o = pickle.dumps(o) o2 = pickle.loads(pk_o) + pk_o2 = pickle.dumps(o2) # so... pickle is creating a new 'mt2' string after a roundtrip here, - # so we'll brute-force set it to be id-equal to the original string - o_mt2_str = [ k for k in o.__dict__ if k == 'mt2'][0] - o2_mt2_str = [ k for k in o2.__dict__ if k == 'mt2'][0] - self.assert_(o_mt2_str == o2_mt2_str) - self.assert_(o_mt2_str is not o2_mt2_str) - # change the id of o2.__dict__['mt2'] - former = o2.__dict__['mt2'] - del o2.__dict__['mt2'] - o2.__dict__[o_mt2_str] = former - - pk_o2 = pickle.dumps(o2) + # so we'll brute-force set it to be id-equal to the original string + if False: + o_mt2_str = [ k for k in o.__dict__ if k == 'mt2'][0] + o2_mt2_str = [ k for k in o2.__dict__ if k == 'mt2'][0] + self.assert_(o_mt2_str == o2_mt2_str) + self.assert_(o_mt2_str is not o2_mt2_str) + # change the id of o2.__dict__['mt2'] + former = o2.__dict__['mt2'] + del o2.__dict__['mt2'] + o2.__dict__[o_mt2_str] = former - self.assert_(pk_o == pk_o2) + self.assert_(pk_o == pk_o2) # the above is kind of distrurbing, so let's do it again a little # differently. 
the string-id in serialization thing is just an @@ -97,19 +95,78 @@ class AttributesTest(PersistTest): self.assert_(o4.mt2[0].a == 'abcde') self.assert_(o4.mt2[0].b is None) - def testlist(self): + def test_deferred(self): + class Foo(object):pass + + data = {'a':'this is a', 'b':12} + def loader(instance, keys): + for k in keys: + instance.__dict__[k] = data[k] + return attributes.ATTR_WAS_SET + + attributes.register_class(Foo, deferred_scalar_loader=loader) + attributes.register_attribute(Foo, 'a', uselist=False, useobject=False) + attributes.register_attribute(Foo, 'b', uselist=False, useobject=False) + + f = Foo() + f._state.expire_attributes(None) + self.assertEquals(f.a, "this is a") + self.assertEquals(f.b, 12) + + f.a = "this is some new a" + f._state.expire_attributes(None) + self.assertEquals(f.a, "this is a") + self.assertEquals(f.b, 12) + + f._state.expire_attributes(None) + f.a = "this is another new a" + self.assertEquals(f.a, "this is another new a") + self.assertEquals(f.b, 12) + + f._state.expire_attributes(None) + self.assertEquals(f.a, "this is a") + self.assertEquals(f.b, 12) + + del f.a + self.assertEquals(f.a, None) + self.assertEquals(f.b, 12) + + f._state.commit_all() + self.assertEquals(f.a, None) + self.assertEquals(f.b, 12) + + def test_deferred_pickleable(self): + data = {'a':'this is a', 'b':12} + def loader(instance, keys): + for k in keys: + instance.__dict__[k] = data[k] + return attributes.ATTR_WAS_SET + + attributes.register_class(MyTest, deferred_scalar_loader=loader) + attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False) + attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False) + + m = MyTest() + m._state.expire_attributes(None) + assert 'a' not in m.__dict__ + m2 = pickle.loads(pickle.dumps(m)) + assert 'a' not in m2.__dict__ + self.assertEquals(m2.a, "this is a") + self.assertEquals(m2.b, 12) + + def test_list(self): class User(object):pass class Address(object):pass - manager = attributes.AttributeManager() - manager.register_attribute(User, 'user_id', uselist = False) - manager.register_attribute(User, 'user_name', uselist = False) - manager.register_attribute(User, 'addresses', uselist = True) - manager.register_attribute(Address, 'address_id', uselist = False) - manager.register_attribute(Address, 'email_address', uselist = False) - - u = User() - print repr(u.__dict__) + attributes.register_class(User) + attributes.register_class(Address) + attributes.register_attribute(User, 'user_id', uselist = False, useobject=False) + attributes.register_attribute(User, 'user_name', uselist = False, useobject=False) + attributes.register_attribute(User, 'addresses', uselist = True, useobject=True) + attributes.register_attribute(Address, 'address_id', uselist = False, useobject=False) + attributes.register_attribute(Address, 'email_address', uselist = False, useobject=False) + + u = User() u.user_id = 7 u.user_name = 'john' u.addresses = [] @@ -118,10 +175,8 @@ class AttributesTest(PersistTest): a.email_address = 'lala@123.com' u.addresses.append(a) - print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - manager.commit(u, a) - print repr(u.__dict__) + u, a._state.commit_all() self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') u.user_name = 'heythere' @@ -129,118 +184,51 @@ class AttributesTest(PersistTest): a.address_id = 11 a.email_address = 'foo@bar.com' u.addresses.append(a) - print 
repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com') - manager.rollback(u, a) - print repr(u.__dict__) - print repr(u.addresses[0].__dict__) - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1) - - def testbackref(self): - class Student(object):pass - class Course(object):pass - manager = attributes.AttributeManager() - manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students')) - manager.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses')) - - s = Student() - c = Course() - s.courses.append(c) - self.assert_(c.students == [s]) - s.courses.remove(c) - self.assert_(c.students == []) - - (s1, s2, s3) = (Student(), Student(), Student()) - - c.students = [s1, s2, s3] - self.assert_(s2.courses == [c]) - self.assert_(s1.courses == [c]) - print "--------------------------------" - print s1 - print s1.courses - print c - print c.students - s1.courses.remove(c) - self.assert_(c.students == [s2,s3]) - class Post(object):pass - class Blog(object):pass - - manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True) - manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True) - b = Blog() - (p1, p2, p3) = (Post(), Post(), Post()) - b.posts.append(p1) - b.posts.append(p2) - b.posts.append(p3) - self.assert_(b.posts == [p1, p2, p3]) - self.assert_(p2.blog is b) - - p3.blog = None - self.assert_(b.posts == [p1, p2]) - p4 = Post() - p4.blog = b - self.assert_(b.posts == [p1, p2, p4]) - - p4.blog = b - p4.blog = b - self.assert_(b.posts == [p1, p2, p4]) - - - class Port(object):pass - class Jack(object):pass - manager.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port')) - manager.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack')) - p = Port() - j = Jack() - p.jack = j - self.assert_(j.port is p) - self.assert_(p.jack is not None) - - j.port = None - self.assert_(p.jack is None) - - def testlazytrackparent(self): + def test_lazytrackparent(self): """test that the "hasparent" flag works properly when lazy loaders and backrefs are used""" - manager = attributes.AttributeManager() class Post(object):pass class Blog(object):pass + attributes.register_class(Post) + attributes.register_class(Blog) - # set up instrumented attributes with backrefs - manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True) - manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True) + # set up instrumented attributes with backrefs + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) # create objects as if they'd been freshly loaded from the database (without history) b = Blog() p1 = Post() - Blog.posts.set_callable(b, lambda:[p1]) - 
Post.blog.set_callable(p1, lambda:b) - manager.commit(p1, b) + b._state.set_callable('posts', lambda:[p1]) + p1._state.set_callable('blog', lambda:b) + p1, b._state.commit_all() # no orphans (called before the lazy loaders fire off) - assert getattr(Blog, 'posts').hasparent(p1, optimistic=True) - assert getattr(Post, 'blog').hasparent(b, optimistic=True) + assert attributes.has_parent(Blog, p1, 'posts', optimistic=True) + assert attributes.has_parent(Post, b, 'blog', optimistic=True) # assert connections assert p1.blog is b assert p1 in b.posts - + # manual connections b2 = Blog() p2 = Post() b2.posts.append(p2) - assert getattr(Blog, 'posts').hasparent(p2) - assert getattr(Post, 'blog').hasparent(b2) - - def testinheritance(self): + assert attributes.has_parent(Blog, p2, 'posts') + assert attributes.has_parent(Post, b2, 'blog') + + def test_inheritance(self): """tests that attributes are polymorphic""" class Foo(object):pass class Bar(Foo):pass - - manager = attributes.AttributeManager() - + + + attributes.register_class(Foo) + attributes.register_class(Bar) + def func1(): print "func1" return "this is the foo attr" @@ -250,10 +238,10 @@ class AttributesTest(PersistTest): def func3(): print "func3" return "this is the shared attr" - manager.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1) - manager.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3) - manager.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2) - + attributes.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1, useobject=True) + attributes.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3, useobject=True) + attributes.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2, useobject=True) + x = Foo() y = Bar() assert x.element == 'this is the foo attr' @@ -261,96 +249,130 @@ class AttributesTest(PersistTest): assert x.element2 == 'this is the shared attr' assert y.element2 == 'this is the shared attr' - def testinheritance2(self): + def test_no_double_state(self): + states = set() + class Foo(object): + def __init__(self): + states.add(self._state) + class Bar(Foo): + def __init__(self): + states.add(self._state) + Foo.__init__(self) + + + attributes.register_class(Foo) + attributes.register_class(Bar) + + b = Bar() + self.assertEquals(len(states), 1) + self.assertEquals(list(states)[0].obj(), b) + + + def test_inheritance2(self): """test that the attribute manager can properly traverse the managed attributes of an object, if the object is of a descendant class with managed attributes in the parent class""" class Foo(object):pass class Bar(Foo):pass - manager = attributes.AttributeManager() - manager.register_attribute(Foo, 'element', uselist=False) + + class Element(object): + _state = True + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'element', uselist=False, useobject=True) + el = Element() x = Bar() - x.element = 'this is the element' - hist = manager.get_history(x, 'element') - assert hist.added_items() == ['this is the element'] - manager.commit(x) - hist = manager.get_history(x, 'element') - assert hist.added_items() == [] - assert hist.unchanged_items() == ['this is the element'] - - def testlazyhistory(self): + x.element = el + self.assertEquals(attributes.get_history(x._state, 'element'), ([el],[], [])) + x._state.commit_all() + + (added, unchanged, deleted) = attributes.get_history(x._state, 'element') + assert 
added == [] + assert unchanged == [el] + + def test_lazyhistory(self): """tests that history functions work with lazy-loading attributes""" - class Foo(object):pass - class Bar(object): - def __init__(self, id): - self.id = id - def __repr__(self): - return "Bar: id %d" % self.id - - manager = attributes.AttributeManager() + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_class(Bar) + + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] def func1(): return "this is func 1" def func2(): - return [Bar(1), Bar(2), Bar(3)] + return [bar1, bar2, bar3] - manager.register_attribute(Foo, 'col1', uselist=False, callable_=lambda o:func1) - manager.register_attribute(Foo, 'col2', uselist=True, callable_=lambda o:func2) - manager.register_attribute(Bar, 'id', uselist=False) + attributes.register_attribute(Foo, 'col1', uselist=False, callable_=lambda o:func1, useobject=True) + attributes.register_attribute(Foo, 'col2', uselist=True, callable_=lambda o:func2, useobject=True) + attributes.register_attribute(Bar, 'id', uselist=False, useobject=True) x = Foo() - manager.commit(x) - x.col2.append(Bar(4)) - h = manager.get_history(x, 'col2') - print h.added_items() - print h.unchanged_items() + x._state.commit_all() + x.col2.append(bar4) + self.assertEquals(attributes.get_history(x._state, 'col2'), ([bar4], [bar1, bar2, bar3], [])) - - def testparenttrack(self): + def test_parenttrack(self): class Foo(object):pass class Bar(object):pass - - manager = attributes.AttributeManager() - - manager.register_attribute(Foo, 'element', uselist=False, trackparent=True) - manager.register_attribute(Bar, 'element', uselist=False, trackparent=True) - + + attributes.register_class(Foo) + attributes.register_class(Bar) + + attributes.register_attribute(Foo, 'element', uselist=False, trackparent=True, useobject=True) + attributes.register_attribute(Bar, 'element', uselist=False, trackparent=True, useobject=True) + f1 = Foo() f2 = Foo() b1 = Bar() b2 = Bar() - + f1.element = b1 b2.element = f2 - - assert manager.get_history(f1, 'element').hasparent(b1) - assert not manager.get_history(f1, 'element').hasparent(b2) - assert not manager.get_history(f1, 'element').hasparent(f2) - assert manager.get_history(b2, 'element').hasparent(f2) - + + assert attributes.has_parent(Foo, b1, 'element') + assert not attributes.has_parent(Foo, b2, 'element') + assert not attributes.has_parent(Foo, f2, 'element') + assert attributes.has_parent(Bar, f2, 'element') + b2.element = None - assert not manager.get_history(b2, 'element').hasparent(f2) + assert not attributes.has_parent(Bar, f2, 'element') - def testmutablescalars(self): + # test that double assignment doesn't accidentally reset the 'parent' flag. 
+ b3 = Bar() + f4 = Foo() + b3.element = f4 + assert attributes.has_parent(Bar, f4, 'element') + b3.element = f4 + assert attributes.has_parent(Bar, f4, 'element') + + def test_mutablescalars(self): """test detection of changes on mutable scalar items""" class Foo(object):pass - manager = attributes.AttributeManager() - manager.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True) + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False) x = Foo() - x.element = ['one', 'two', 'three'] - manager.commit(x) + x.element = ['one', 'two', 'three'] + x._state.commit_all() x.element[1] = 'five' - assert manager.is_modified(x) - - manager.reset_class_managed(Foo) - manager = attributes.AttributeManager() - manager.register_attribute(Foo, 'element', uselist=False) + assert x._state.is_modified() + + attributes.unregister_class(Foo) + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'element', uselist=False, useobject=False) x = Foo() - x.element = ['one', 'two', 'three'] - manager.commit(x) + x.element = ['one', 'two', 'three'] + x._state.commit_all() x.element[1] = 'five' - assert not manager.is_modified(x) - - def testdescriptorattributes(self): + assert not x._state.is_modified() + + def test_descriptorattributes(self): """changeset: 1633 broke ability to use ORM to map classes with unusual descriptor attributes (for example, classes that inherit from ones implementing zope.interface.Interface). @@ -362,21 +384,24 @@ class AttributesTest(PersistTest): class Foo(object): A = des() - manager = attributes.AttributeManager() - manager.reset_class_managed(Foo) - - def testcollectionclasses(self): - manager = attributes.AttributeManager() + + attributes.unregister_class(Foo) + + def test_collectionclasses(self): + class Foo(object):pass - manager.register_attribute(Foo, "collection", uselist=True, typecallable=set) + attributes.register_class(Foo) + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=set, useobject=True) assert isinstance(Foo().collection, set) - + + attributes.unregister_attribute(Foo, "collection") + try: - manager.register_attribute(Foo, "collection", uselist=True, typecallable=dict) + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=dict, useobject=True) assert False except exceptions.ArgumentError, e: assert str(e) == "Type InstrumentedDict must elect an appender method to be a collection class" - + class MyDict(dict): @collection.appender def append(self, item): @@ -384,16 +409,18 @@ class AttributesTest(PersistTest): @collection.remover def remove(self, item): del self[item.foo] - manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyDict) + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyDict, useobject=True) assert isinstance(Foo().collection, MyDict) - + + attributes.unregister_attribute(Foo, "collection") + class MyColl(object):pass try: - manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl) + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True) assert False except exceptions.ArgumentError, e: assert str(e) == "Type MyColl must elect an appender method to be a collection class" - + class MyColl(object): @collection.iterator def __iter__(self): @@ -404,12 +431,685 @@ class AttributesTest(PersistTest): 
@collection.remover def remove(self, item): pass - manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl) + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True) try: Foo().collection assert True except exceptions.ArgumentError, e: assert False - + + +class BackrefTest(TestBase): + + def test_manytomany(self): + class Student(object):pass + class Course(object):pass + + attributes.register_class(Student) + attributes.register_class(Course) + attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True) + attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True) + + s = Student() + c = Course() + s.courses.append(c) + self.assert_(c.students == [s]) + s.courses.remove(c) + self.assert_(c.students == []) + + (s1, s2, s3) = (Student(), Student(), Student()) + + c.students = [s1, s2, s3] + self.assert_(s2.courses == [c]) + self.assert_(s1.courses == [c]) + s1.courses.remove(c) + self.assert_(c.students == [s2,s3]) + + def test_onetomany(self): + class Post(object):pass + class Blog(object):pass + + attributes.register_class(Post) + attributes.register_class(Blog) + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) + b = Blog() + (p1, p2, p3) = (Post(), Post(), Post()) + b.posts.append(p1) + b.posts.append(p2) + b.posts.append(p3) + self.assert_(b.posts == [p1, p2, p3]) + self.assert_(p2.blog is b) + + p3.blog = None + self.assert_(b.posts == [p1, p2]) + p4 = Post() + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + p4.blog = b + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + # assert no failure removing None + p5 = Post() + p5.blog = None + del p5.blog + + def test_onetoone(self): + class Port(object):pass + class Jack(object):pass + attributes.register_class(Port) + attributes.register_class(Jack) + attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True) + attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True) + p = Port() + j = Jack() + p.jack = j + self.assert_(j.port is p) + self.assert_(p.jack is not None) + + j.port = None + self.assert_(p.jack is None) + +class DeferredBackrefTest(TestBase): + def setUp(self): + global Post, Blog, called, lazy_load + + class Post(object): + def __init__(self, name): + self.name = name + def __eq__(self, other): + return other.name == self.name + + class Blog(object): + def __init__(self, name): + self.name = name + def __eq__(self, other): + return other.name == self.name + + called = [0] + + lazy_load = [] + def lazy_posts(instance): + def load(): + called[0] += 1 + return lazy_load + return load + + attributes.register_class(Post) + attributes.register_class(Blog) + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True) + + def test_lazy_add(self): + global lazy_load 
+ + p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3") + lazy_load = [p1, p2, p3] + + b = Blog("blog 1") + p = Post("post 4") + p.blog = b + p = Post("post 5") + p.blog = b + # setting blog doesnt call 'posts' callable + assert called[0] == 0 + + # calling backref calls the callable, populates extra posts + assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")] + assert called[0] == 1 + + def test_lazy_remove(self): + global lazy_load + called[0] = 0 + lazy_load = [] + + b = Blog("blog 1") + p = Post("post 1") + p.blog = b + assert called[0] == 0 + + lazy_load = [p] + + p.blog = None + p2 = Post("post 2") + p2.blog = b + assert called[0] == 0 + assert b.posts == [p2] + assert called[0] == 1 + + def test_normal_load(self): + global lazy_load + lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + called[0] = 0 + + b = Blog("blog 1") + + # assign without using backref system + p2.__dict__['blog'] = b + + assert b.posts == [Post("post 1"), Post("post 2"), Post("post 3")] + assert called[0] == 1 + p2.blog = None + p4 = Post("post 4") + p4.blog = b + assert b.posts == [Post("post 1"), Post("post 3"), Post("post 4")] + assert called[0] == 1 + + called[0] = 0 + lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + +class HistoryTest(TestBase): + def test_get_committed_value(self): + class Foo(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) + + f = Foo() + self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None) + + f.someattr = 3 + self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None) + + f = Foo() + f.someattr = 3 + self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None) + + f._state.commit(['someattr']) + self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), 3) + + def test_scalar(self): + class Foo(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) + + # case 1. new object + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr = "hi" + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['hi'], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['hi'], [])) + + f.someattr = 'there' + + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['there'], [], ['hi'])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['there'], [])) + + del f.someattr + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], ['there'])) + + # case 2. 
object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = 'new' + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], [])) + + f.someattr = 'old' + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['old'], [], ['new'])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['old'], [])) + + # setting None on uninitialized is currently a change for a scalar attribute + # no lazyload occurs so this allows overwrite operation to proceed + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + f.someattr = None + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [])) + + f = Foo() + f.__dict__['someattr'] = 'new' + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], [])) + f.someattr = None + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], ['new'])) + + def test_mutable_scalar(self): + class Foo(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False, mutable_scalars=True, copy_function=dict) + + # case 1. new object + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr = {'foo':'hi'} + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'hi'}], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'hi'}], [])) + self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'}) + + f.someattr['foo'] = 'there' + self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'}) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'there'}], [], [{'foo':'hi'}])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'there'}], [])) + + # case 2. object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = {'foo':'new'} + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'new'}], [])) + + f.someattr = {'foo':'old'} + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'old'}], [], [{'foo':'new'}])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'old'}], [])) + + + def test_use_object(self): + class Foo(fixtures.Base): + pass + + class Bar(fixtures.Base): + _state = None + def __nonzero__(self): + assert False + + hi = Bar(name='hi') + there = Bar(name='there') + new = Bar(name='new') + old = Bar(name='old') + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=True) + + # case 1. 
new object + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], [])) + + f.someattr = hi + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [])) + + f.someattr = there + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], [])) + + del f.someattr + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [there])) + + # case 2. object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = new + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], [])) + + f.someattr = old + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], [])) + + # setting None on uninitialized is currently not a change for an object attribute + # (this is different than scalar attribute). a lazyload has occured so if its + # None, its really None + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], [])) + f.someattr = None + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], [])) + + f = Foo() + f.__dict__['someattr'] = new + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], [])) + f.someattr = None + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [new])) + + def test_object_collections_set(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + def __nonzero__(self): + assert False + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True) + + hi = Bar(name='hi') + there = Bar(name='there') + old = Bar(name='old') + new = Bar(name='new') + + # case 1. new object + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr = [hi] + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [])) + + f.someattr = [there] + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], [])) + + f.someattr = [hi] + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [there])) + + f.someattr = [old, new] + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [], [there])) + + # case 2. 
object with direct settings (similar to a load operation) + f = Foo() + collection = attributes.init_collection(f, 'someattr') + collection.append_without_event(new) + f._state.commit_all() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], [])) + + f.someattr = [old] + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], [])) + + def test_dict_collections(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + from sqlalchemy.orm.collections import attribute_mapped_collection + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True, typecallable=attribute_mapped_collection('name')) + + hi = Bar(name='hi') + there = Bar(name='there') + old = Bar(name='old') + new = Bar(name='new') + + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr['hi'] = hi + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [])) + + f.someattr['there'] = there + self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([hi, there]), set([]), set([]))) + + f._state.commit(['someattr']) + self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([]), set([hi, there]), set([]))) + + def test_object_collections_mutate(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True) + attributes.register_attribute(Foo, 'id', uselist=False, useobject=False) + + hi = Bar(name='hi') + there = Bar(name='there') + old = Bar(name='old') + new = Bar(name='new') + + # case 1. new object + f = Foo(id=1) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr.append(hi) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [])) + + f.someattr.append(there) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [hi], [])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, there], [])) + + f.someattr.remove(there) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [there])) + + f.someattr.append(old) + f.someattr.append(new) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [hi], [there])) + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, old, new], [])) + + f.someattr.pop(0) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old, new], [hi])) + + # case 2. 
object with direct settings (similar to a load operation) + f = Foo() + f.__dict__['id'] = 1 + collection = attributes.init_collection(f, 'someattr') + collection.append_without_event(new) + f._state.commit_all() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], [])) + + f.someattr.append(old) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [new], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new, old], [])) + + f = Foo() + collection = attributes.init_collection(f, 'someattr') + collection.append_without_event(new) + f._state.commit_all() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], [])) + + f.id = 1 + f.someattr.remove(new) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [new])) + + # case 3. mixing appends with sets + f = Foo() + f.someattr.append(hi) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [])) + f.someattr.append(there) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi, there], [], [])) + f.someattr = [there] + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [])) + + def test_collections_via_backref(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, useobject=True) + attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) + + f1 = Foo() + b1 = Bar() + self.assertEquals(attributes.get_history(f1._state, 'bars'), ([], [], [])) + self.assertEquals(attributes.get_history(b1._state, 'foo'), ([], [None], [])) + + #b1.foo = f1 + f1.bars.append(b1) + self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1], [], [])) + self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], [])) + + b2 = Bar() + f1.bars.append(b2) + self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1, b2], [], [])) + self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], [])) + self.assertEquals(attributes.get_history(b2._state, 'foo'), ([f1], [], [])) + + def test_lazy_backref_collections(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + lazy_load = [] + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, callable_=lazyload, useobject=True) + attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) + + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] + lazy_load = [bar1, bar2, bar3] + + f = Foo() + bar4 = Bar() + bar4.foo = f + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], [])) + + lazy_load = None + f = Foo() + bar4 = Bar() + bar4.foo = f + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], [])) + + lazy_load = [bar1, bar2, bar3] + f._state.expire_attributes(['bars']) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], [])) + + def 
test_collections_via_lazyload(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + lazy_load = [] + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lazyload, trackparent=True, useobject=True) + + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] + lazy_load = [bar1, bar2, bar3] + + f = Foo() + f.bars = [] + self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [], [bar1, bar2, bar3])) + + f = Foo() + f.bars.append(bar4) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []) ) + + f = Foo() + f.bars.remove(bar2) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2])) + f.bars.append(bar4) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar3], [bar2])) + + f = Foo() + del f.bars[1] + self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2])) + + lazy_load = None + f = Foo() + f.bars.append(bar2) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar2], [], [])) + + def test_scalar_via_lazyload(self): + class Foo(fixtures.Base): + pass + + lazy_load = None + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, useobject=False) + lazy_load = "hi" + + # with scalar non-object, the lazy callable is only executed on gets, not history + # operations + + f = Foo() + self.assertEquals(f.bar, "hi") + self.assertEquals(attributes.get_history(f._state, 'bar'), ([], ["hi"], [])) + + f = Foo() + f.bar = None + self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [])) + + f = Foo() + f.bar = "there" + self.assertEquals(attributes.get_history(f._state, 'bar'), (["there"], [], [])) + f.bar = "hi" + self.assertEquals(attributes.get_history(f._state, 'bar'), (["hi"], [], [])) + + f = Foo() + self.assertEquals(f.bar, "hi") + del f.bar + self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [], ["hi"])) + assert f.bar is None + self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], ["hi"])) + + def test_scalar_object_via_lazyload(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + lazy_load = None + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, trackparent=True, useobject=True) + bar1, bar2 = [Bar(id=1), Bar(id=2)] + lazy_load = bar1 + + # with scalar object, the lazy callable is only executed on gets and history + # operations + + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], [])) + + f = Foo() + f.bar = None + self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1])) + + f = Foo() + f.bar = bar2 + self.assertEquals(attributes.get_history(f._state, 'bar'), ([bar2], [], [bar1])) + f.bar = bar1 + self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], [])) + + f = Foo() + self.assertEquals(f.bar, bar1) + del f.bar + self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1])) + assert f.bar is None + self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1])) + if 
__name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/cascade.py b/test/orm/cascade.py index b832c427e0..7a68a4d58a 100644 --- a/test/orm/cascade.py +++ b/test/orm/cascade.py @@ -1,215 +1,195 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions from sqlalchemy.orm import * -from sqlalchemy.ext.sessioncontext import SessionContext from testlib import * -import testlib.tables as tables +from testlib import fixtures -class O2MCascadeTest(AssertMixin): - def tearDown(self): - tables.delete() +class O2MCascadeTest(fixtures.FixtureTest): + keep_mappers = True + keep_data = False + refresh_data = False - def tearDownAll(self): - clear_mappers() - tables.drop() - - def setUpAll(self): - global data - tables.create() - mapper(tables.User, tables.users, properties = dict( - address = relation(mapper(tables.Address, tables.addresses), lazy=True, uselist = False, cascade="all, delete-orphan"), + def setup_mappers(self): + global User, Address, Order, users, orders, addresses + from testlib.fixtures import User, Address, Order, users, orders, addresses + + mapper(Address, addresses) + mapper(User, users, properties = dict( + addresses = relation(Address, cascade="all, delete-orphan"), orders = relation( - mapper(tables.Order, tables.orders, properties = dict ( - items = relation(mapper(tables.Item, tables.orderitems), lazy=True, uselist =True, cascade="all, delete-orphan") - )), - lazy = True, uselist = True, cascade="all, delete-orphan") + mapper(Order, orders), cascade="all, delete-orphan") )) - - def setUp(self): - global data - data = [tables.User, - {'user_name' : 'ed', - 'address' : (tables.Address, {'email_address' : 'foo@bar.com'}), - 'orders' : (tables.Order, [ - {'description' : 'eds 1st order', 'items' : (tables.Item, [{'item_name' : 'eds o1 item'}, {'item_name' : 'eds other o1 item'}])}, - {'description' : 'eds 2nd order', 'items' : (tables.Item, [{'item_name' : 'eds o2 item'}, {'item_name' : 'eds other o2 item'}])} - ]) - }, - {'user_name' : 'jack', - 'address' : (tables.Address, {'email_address' : 'jack@jack.com'}), - 'orders' : (tables.Order, [ - {'description' : 'jacks 1st order', 'items' : (tables.Item, [{'item_name' : 'im a lumberjack'}, {'item_name' : 'and im ok'}])} - ]) - }, - {'user_name' : 'foo', - 'address' : (tables.Address, {'email_address': 'hi@lala.com'}), - 'orders' : (tables.Order, [ - {'description' : 'foo order', 'items' : (tables.Item, [])}, - {'description' : 'foo order 2', 'items' : (tables.Item, [{'item_name' : 'hi'}])}, - {'description' : 'foo order three', 'items' : (tables.Item, [{'item_name' : 'there'}])} - ]) - } - ] - - sess = create_session() - for elem in data[1:]: - u = tables.User() - sess.save(u) - u.user_name = elem['user_name'] - u.address = tables.Address() - u.address.email_address = elem['address'][1]['email_address'] - u.orders = [] - for order in elem['orders'][1]: - o = tables.Order() - o.isopen = None - o.description = order['description'] - u.orders.append(o) - o.items = [] - for item in order['items'][1]: - i = tables.Item() - i.item_name = item['item_name'] - o.items.append(i) - - sess.flush() - sess.clear() - - def testassignlist(self): + + def test_list_assignment(self): sess = create_session() - u = tables.User() - u.user_name = 'jack' - o1 = tables.Order() - o1.description ='someorder' - o2 = tables.Order() - o2.description = 'someotherorder' - l = [o1, o2] + u = User(name='jack', orders=[Order(description='someorder'), 
Order(description='someotherorder')]) sess.save(u) - u.orders = l - assert o1 in sess - assert o2 in sess sess.flush() sess.clear() - u = sess.query(tables.User).get(u.user_id) - o3 = tables.Order() - o3.description='order3' - o4 = tables.Order() - o4.description = 'order4' - u.orders = [o3, o4] - assert o3 in sess - assert o4 in sess + u = sess.query(User).get(u.id) + self.assertEquals(u, User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')])) + + u.orders=[Order(description="order 3"), Order(description="order 4")] sess.flush() + sess.clear() - o5 = tables.Order() - o5.description='order5' + u = sess.query(User).get(u.id) + self.assertEquals(u, User(name='jack', orders=[Order(description="order 3"), Order(description="order 4")])) + + self.assertEquals(sess.query(Order).all(), [Order(description="order 3"), Order(description="order 4")]) + o5 = Order(description="order 5") sess.save(o5) try: sess.flush() assert False except exceptions.FlushError, e: assert "is an orphan" in str(e) - - - def testdelete(self): - sess = create_session() - l = sess.query(tables.User).select() - for u in l: - print repr(u.orders) - self.assert_result(l, data[0], *data[1:]) - ids = (l[0].user_id, l[2].user_id) - sess.delete(l[0]) - sess.delete(l[2]) + def test_delete(self): + sess = create_session() + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) + sess.save(u) + sess.flush() + sess.delete(u) sess.flush() - assert tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 0 - assert tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0 - assert tables.addresses.count(tables.addresses.c.user_id.in_(*ids)).scalar() == 0 - assert tables.users.count(tables.users.c.user_id.in_(*ids)).scalar() == 0 + assert users.count().scalar() == 0 + assert orders.count().scalar() == 0 - def testdelete2(self): + def test_delete_unloaded_collections(self): """test that unloaded collections are still included in a delete-cascade by default.""" - + sess = create_session() - u = sess.query(tables.User).get_by(user_name='ed') - # assert 'addresses' collection not loaded + u = User(name='jack', addresses=[Address(email_address="address1"), Address(email_address="address2")]) + sess.save(u) + sess.flush() + sess.clear() + assert addresses.count().scalar() == 2 + assert users.count().scalar() == 1 + + u = sess.query(User).get(u.id) + assert 'addresses' not in u.__dict__ sess.delete(u) sess.flush() - assert tables.addresses.count(tables.addresses.c.email_address=='foo@bar.com').scalar() == 0 - assert tables.orderitems.count(tables.orderitems.c.item_name.like('eds%')).scalar() == 0 + assert addresses.count().scalar() == 0 + assert users.count().scalar() == 0 - def testcascadecollection(self): + def test_cascades_onlycollection(self): """test that cascade only reaches instances that are still part of the collection, not those that have been removed""" + sess = create_session() - - u = tables.User() - u.user_name = 'newuser' - o = tables.Order() - o.description = "some description" - u.orders.append(o) + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) sess.save(u) sess.flush() - u.orders.remove(o) + o = u.orders[0] + del u.orders[0] sess.delete(u) assert u in sess.deleted assert o not in sess.deleted - + assert o in sess + + u2 = User(name='newuser', orders=[o]) + sess.save(u2) + sess.flush() + sess.clear() + 
assert users.count().scalar() == 1 + assert orders.count().scalar() == 1 + self.assertEquals(sess.query(User).all(), [User(name='newuser', orders=[Order(description='someorder')])]) - def testorphan(self): + def test_cascade_delete_plusorphans(self): sess = create_session() - l = sess.query(tables.User).select() - jack = l[1] - jack.orders[:] = [] + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) + sess.save(u) + sess.flush() + assert users.count().scalar() == 1 + assert orders.count().scalar() == 2 + + del u.orders[0] + sess.delete(u) + sess.flush() + assert users.count().scalar() == 0 + assert orders.count().scalar() == 0 + + def test_collection_orphans(self): + sess = create_session() + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) + sess.save(u) + sess.flush() - ids = [jack.user_id] - self.assert_(tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 1) - self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 2) + assert users.count().scalar() == 1 + assert orders.count().scalar() == 2 + + u.orders[:] = [] sess.flush() - self.assert_(tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 0) - self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0) + assert users.count().scalar() == 1 + assert orders.count().scalar() == 0 +class O2MCascadeNoOrphanTest(fixtures.FixtureTest): + keep_mappers = True + keep_data = False + refresh_data = False -class M2OCascadeTest(AssertMixin): - def tearDown(self): - ctx.current.clear() - for t in metadata.table_iterator(reverse=True): - t.delete().execute() - - def tearDownAll(self): - clear_mappers() - metadata.drop_all() + def setup_mappers(self): + global User, Address, Order, users, orders, addresses + from testlib.fixtures import User, Address, Order, users, orders, addresses + + mapper(User, users, properties = dict( + orders = relation( + mapper(Order, orders), cascade="all") + )) + + def test_cascade_delete_noorphans(self): + sess = create_session() + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) + sess.save(u) + sess.flush() + assert users.count().scalar() == 1 + assert orders.count().scalar() == 2 - def setUpAll(self): - global ctx, data, metadata, User, Pref, Extra - ctx = SessionContext(create_session) - metadata = MetaData(testbase.db) - extra = Table("extra", metadata, - Column("extra_id", Integer, Sequence("extra_id_seq", optional=True), primary_key=True), - Column("prefs_id", Integer, ForeignKey("prefs.prefs_id")) + del u.orders[0] + sess.delete(u) + sess.flush() + assert users.count().scalar() == 0 + assert orders.count().scalar() == 1 + + +class M2OCascadeTest(ORMTest): + keep_mappers = True + + def define_tables(self, metadata): + global extra, prefs, users + + extra = Table("extra", metadata, + Column("id", Integer, Sequence("extra_id_seq", optional=True), primary_key=True), + Column("prefs_id", Integer, ForeignKey("prefs.id")) ) - prefs = Table('prefs', metadata, - Column('prefs_id', Integer, Sequence('prefs_id_seq', optional=True), primary_key=True), - Column('prefs_data', String(40))) - + prefs = Table('prefs', metadata, + Column('id', Integer, Sequence('prefs_id_seq', optional=True), primary_key=True), + Column('data', String(40))) + users = Table('users', metadata, - 
Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), - Column('user_name', String(40)), - Column('pref_id', Integer, ForeignKey('prefs.prefs_id')) + Column('id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), + Column('name', String(40)), + Column('pref_id', Integer, ForeignKey('prefs.id')) ) - class User(object): - def __init__(self, name): - self.user_name = name - class Pref(object): - def __init__(self, data): - self.prefs_data = data - class Extra(object): + + def setup_mappers(self): + global User, Pref, Extra + class User(fixtures.Base): + pass + class Pref(fixtures.Base): pass - metadata.create_all() + class Extra(fixtures.Base): + pass + mapper(Extra, extra) mapper(Pref, prefs, properties=dict( extra = relation(Extra, cascade="all, delete") @@ -219,118 +199,342 @@ class M2OCascadeTest(AssertMixin): )) def setUp(self): - u1 = User("ed") - u1.pref = Pref("pref 1") - u2 = User("jack") - u2.pref = Pref("pref 2") - u3 = User("foo") - u3.pref = Pref("pref 3") - u1.pref.extra.append(Extra()) - u2.pref.extra.append(Extra()) - u2.pref.extra.append(Extra()) - - ctx.current.save(u1) - ctx.current.save(u2) - ctx.current.save(u3) - ctx.current.flush() - ctx.current.clear() - - def testorphan(self): - jack = ctx.current.query(User).get_by(user_name='jack') - p = jack.pref - e = jack.pref.extra[0] + u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()])) + u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()])) + u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()])) + sess = create_session() + sess.save(u1) + sess.save(u2) + sess.save(u3) + sess.flush() + sess.close() + + @testing.fails_on('maxdb') + def test_orphan(self): + sess = create_session() + assert prefs.count().scalar() == 3 + assert extra.count().scalar() == 3 + jack = sess.query(User).filter_by(name="jack").one() jack.pref = None - ctx.current.flush() - assert p not in ctx.current - assert e not in ctx.current + sess.flush() + assert prefs.count().scalar() == 2 + assert extra.count().scalar() == 2 - def testorphan2(self): - jack = ctx.current.query(User).get_by(user_name='jack') + @testing.fails_on('maxdb') + def test_orphan_on_update(self): + sess = create_session() + jack = sess.query(User).filter_by(name="jack").one() p = jack.pref e = jack.pref.extra[0] - ctx.current.clear() + sess.clear() jack.pref = None - ctx.current.update(jack) - ctx.current.update(p) - ctx.current.update(e) - assert p in ctx.current - assert e in ctx.current - ctx.current.flush() - assert p not in ctx.current - assert e not in ctx.current + sess.update(jack) + sess.update(p) + sess.update(e) + assert p in sess + assert e in sess + sess.flush() + assert prefs.count().scalar() == 2 + assert extra.count().scalar() == 2 + def test_pending_expunge(self): + sess = create_session() + someuser = User(name='someuser') + sess.save(someuser) + sess.flush() + someuser.pref = p1 = Pref(data='somepref') + assert p1 in sess + someuser.pref = Pref(data='someotherpref') + assert p1 not in sess + sess.flush() + self.assertEquals(sess.query(Pref).with_parent(someuser).all(), [Pref(data="someotherpref")]) + + def test_double_assignment(self): + """test that double assignment doesn't accidentally reset the 'parent' flag.""" -class M2MCascadeTest(AssertMixin): - def setUpAll(self): - global metadata, a, b, atob - metadata = MetaData(testbase.db) - a = Table('a', metadata, + sess = create_session() + jack = sess.query(User).filter_by(name="jack").one() + + newpref = Pref(data="newpref") + 
jack.pref = newpref + jack.pref = newpref + sess.flush() + self.assertEquals(sess.query(Pref).all(), [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) + +class M2OCascadeDeleteTest(ORMTest): + keep_mappers = True + + def define_tables(self, metadata): + global t1, t2, t3 + t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), Column('t2id', Integer, ForeignKey('t2.id'))) + t2 = Table('t2', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), Column('t3id', Integer, ForeignKey('t3.id'))) + t3 = Table('t3', metadata, Column('id', Integer, primary_key=True), Column('data', String(50))) + + def setup_mappers(self): + global T1, T2, T3 + class T1(fixtures.Base):pass + class T2(fixtures.Base):pass + class T3(fixtures.Base):pass + + mapper(T1, t1, properties={'t2':relation(T2, cascade="all")}) + mapper(T2, t2, properties={'t3':relation(T3, cascade="all")}) + mapper(T3, t3) + + def test_cascade_delete(self): + sess = create_session() + + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.save(x) + sess.flush() + + sess.delete(x) + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), []) + self.assertEquals(sess.query(T3).all(), []) + + def test_cascade_delete_postappend_onelevel(self): + sess = create_session() + + x1 = T1(data='t1', ) + x2 = T2(data='t2') + x3 = T3(data='t3') + sess.save(x1) + sess.save(x2) + sess.save(x3) + sess.flush() + + sess.delete(x1) + x1.t2 = x2 + x2.t3 = x3 + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), []) + self.assertEquals(sess.query(T3).all(), []) + + def test_cascade_delete_postappend_twolevel(self): + sess = create_session() + + x1 = T1(data='t1', t2=T2(data='t2')) + x3 = T3(data='t3') + sess.save(x1) + sess.save(x3) + sess.flush() + + sess.delete(x1) + x1.t2.t3 = x3 + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), []) + self.assertEquals(sess.query(T3).all(), []) + + def test_preserves_orphans_onelevel(self): + sess = create_session() + + x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + sess.save(x2) + sess.flush() + x2.t2 = None + + sess.delete(x2) + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), [T2()]) + self.assertEquals(sess.query(T3).all(), [T3()]) + + @testing.future + def test_preserves_orphans_onelevel_postremove(self): + sess = create_session() + + x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + sess.save(x2) + sess.flush() + + sess.delete(x2) + x2.t2 = None + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), [T2()]) + self.assertEquals(sess.query(T3).all(), [T3()]) + + def test_preserves_orphans_twolevel(self): + sess = create_session() + + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.save(x) + sess.flush() + + x.t2.t3 = None + sess.delete(x) + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), []) + self.assertEquals(sess.query(T3).all(), [T3()]) + +class M2OCascadeDeleteOrphanTest(ORMTest): + keep_mappers = True + + def define_tables(self, metadata): + global t1, t2, t3 + t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), Column('t2id', Integer, ForeignKey('t2.id'))) + t2 = Table('t2', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), 
Column('t3id', Integer, ForeignKey('t3.id'))) + t3 = Table('t3', metadata, Column('id', Integer, primary_key=True), Column('data', String(50))) + + def setup_mappers(self): + global T1, T2, T3 + class T1(fixtures.Base):pass + class T2(fixtures.Base):pass + class T3(fixtures.Base):pass + + mapper(T1, t1, properties={'t2':relation(T2, cascade="all, delete-orphan")}) + mapper(T2, t2, properties={'t3':relation(T3, cascade="all, delete-orphan")}) + mapper(T3, t3) + + def test_cascade_delete(self): + sess = create_session() + + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.save(x) + sess.flush() + + sess.delete(x) + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), []) + self.assertEquals(sess.query(T3).all(), []) + + def test_deletes_orphans_onelevel(self): + sess = create_session() + + x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + sess.save(x2) + sess.flush() + x2.t2 = None + + sess.delete(x2) + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), []) + self.assertEquals(sess.query(T3).all(), []) + + def test_deletes_orphans_twolevel(self): + sess = create_session() + + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.save(x) + sess.flush() + + x.t2.t3 = None + sess.delete(x) + sess.flush() + self.assertEquals(sess.query(T1).all(), []) + self.assertEquals(sess.query(T2).all(), []) + self.assertEquals(sess.query(T3).all(), []) + + def test_finds_orphans_twolevel(self): + sess = create_session() + + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.save(x) + sess.flush() + + x.t2.t3 = None + sess.flush() + self.assertEquals(sess.query(T1).all(), [T1()]) + self.assertEquals(sess.query(T2).all(), [T2()]) + self.assertEquals(sess.query(T3).all(), []) + +class M2MCascadeTest(ORMTest): + def define_tables(self, metadata): + global a, b, atob, c + a = Table('a', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)) ) - b = Table('b', metadata, + b = Table('b', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)) ) - atob = Table('atob', metadata, + atob = Table('atob', metadata, Column('aid', Integer, ForeignKey('a.id')), Column('bid', Integer, ForeignKey('b.id')) - ) - metadata.create_all() - - def tearDownAll(self): - metadata.drop_all() - - def testdeleteorphan(self): - class A(object): - def __init__(self, data): - self.data = data - class B(object): - def __init__(self, data): - self.data = data - + c = Table('c', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('bid', Integer, ForeignKey('b.id')) + ) + + def test_delete_orphan(self): + class A(fixtures.Base): + pass + class B(fixtures.Base): + pass + mapper(A, a, properties={ # if no backref here, delete-orphan failed until [ticket:427] was fixed 'bs':relation(B, secondary=atob, cascade="all, delete-orphan") }) mapper(B, b) - + sess = create_session() - a1 = A('a1') - b1 = B('b1') - a1.bs.append(b1) + b1 = B(data='b1') + a1 = A(data='a1', bs=[b1]) sess.save(a1) sess.flush() - + a1.bs.remove(b1) sess.flush() assert atob.count().scalar() ==0 assert b.count().scalar() == 0 assert a.count().scalar() == 1 - - def testcascadedelete(self): - class A(object): - def __init__(self, data): - self.data = data - class B(object): - def __init__(self, data): - self.data = data + + def test_delete_orphan_cascades(self): + class A(fixtures.Base): + pass + class B(fixtures.Base): + pass + class C(fixtures.Base): + 
pass + + mapper(A, a, properties={ + # if no backref here, delete-orphan failed until [ticket:427] was fixed + 'bs':relation(B, secondary=atob, cascade="all, delete-orphan") + }) + mapper(B, b, properties={'cs':relation(C, cascade="all, delete-orphan")}) + mapper(C, c) + + sess = create_session() + b1 = B(data='b1', cs=[C(data='c1')]) + a1 = A(data='a1', bs=[b1]) + sess.save(a1) + sess.flush() + + a1.bs.remove(b1) + sess.flush() + assert atob.count().scalar() ==0 + assert b.count().scalar() == 0 + assert a.count().scalar() == 1 + assert c.count().scalar() == 0 + def test_cascade_delete(self): + class A(fixtures.Base): + pass + class B(fixtures.Base): + pass + mapper(A, a, properties={ 'bs':relation(B, secondary=atob, cascade="all, delete-orphan") }) mapper(B, b) - + sess = create_session() - a1 = A('a1') - b1 = B('b1') - a1.bs.append(b1) + a1 = A(data='a1', bs=[B(data='b1')]) sess.save(a1) sess.flush() - + sess.delete(a1) sess.flush() assert atob.count().scalar() ==0 @@ -339,12 +543,12 @@ class M2MCascadeTest(AssertMixin): class UnsavedOrphansTest(ORMTest): """tests regarding pending entities that are orphans""" - + def define_tables(self, metadata): global users, addresses, User, Address users = Table('users', metadata, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), - Column('user_name', String(40)), + Column('name', String(40)), ) addresses = Table('email_addresses', metadata, @@ -352,10 +556,10 @@ class UnsavedOrphansTest(ORMTest): Column('user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(40)), ) - class User(object):pass - class Address(object):pass - - def test_pending_orphan(self): + class User(fixtures.Base):pass + class Address(fixtures.Base):pass + + def test_pending_standalone_orphan(self): """test that an entity that never had a parent on a delete-orphan cascade cant be saved.""" mapper(Address, addresses) @@ -371,9 +575,8 @@ class UnsavedOrphansTest(ORMTest): pass assert a.address_id is None, "Error: address should not be persistent" - def test_delete_new_object(self): - """test that an entity which is attached then detached from its - parent with a delete-orphan cascade gets counted as an orphan""" + def test_pending_collection_expunge(self): + """test that removing a pending item from a collection expunges it from the session.""" mapper(Address, addresses) mapper(User, users, properties=dict( @@ -383,20 +586,35 @@ class UnsavedOrphansTest(ORMTest): u = User() s.save(u) + s.flush() a = Address() - assert a not in s.new + u.addresses.append(a) + assert a in s + u.addresses.remove(a) - s.delete(u) - try: - s.flush() # (erroneously) causes "a" to be persisted - assert False - except exceptions.FlushError: - assert True - assert u.user_id is None, "Error: user should not be persistent" - assert a.address_id is None, "Error: address should not be persistent" + assert a not in s + s.delete(u) + s.flush() + assert a.address_id is None, "Error: address should not be persistent" + + def test_nonorphans_ok(self): + mapper(Address, addresses) + mapper(User, users, properties=dict( + addresses=relation(Address, cascade="all,delete", backref="user") + )) + s = create_session() + u = User(name='u1', addresses=[Address(email_address='ad1')]) + s.save(u) + a1 = u.addresses[0] + u.addresses.remove(a1) + assert a1 in s + s.flush() + s.clear() + self.assertEquals(s.query(Address).all(), [Address(email_address='ad1')]) + class UnsavedOrphansTest2(ORMTest): """same test as UnsavedOrphans only three levels deep""" @@ -420,56 
+638,107 @@ class UnsavedOrphansTest2(ORMTest):
         )
-    def testdeletechildwithchild(self):
-        """test that an entity which is attached then detached from its
-        parent with a delete-orphan cascade gets counted as an orphan, as well
-        as its own child instances"""
-
-        class Order(object): pass
-        class Item(object): pass
-        class Attribute(object): pass
+    def test_pending_expunge(self):
+        class Order(fixtures.Base):
+            pass
+        class Item(fixtures.Base):
+            pass
+        class Attribute(fixtures.Base):
+            pass
-        attrMapper = mapper(Attribute, attributes)
-        itemMapper = mapper(Item, items, properties=dict(
-            attributes=relation(attrMapper, cascade="all,delete-orphan", backref="item")
+        mapper(Attribute, attributes)
+        mapper(Item, items, properties=dict(
+            attributes=relation(Attribute, cascade="all,delete-orphan", backref="item")
         ))
-        orderMapper = mapper(Order, orders, properties=dict(
-            items=relation(itemMapper, cascade="all,delete-orphan", backref="order")
+        mapper(Order, orders, properties=dict(
+            items=relation(Item, cascade="all,delete-orphan", backref="order")
         ))
-        s = create_session( )
-        order = Order()
+        s = create_session()
+        order = Order(name="order1")
         s.save(order)
-        item = Item()
-        attr = Attribute()
-        item.attributes.append(attr)
+        attr = Attribute(name="attr1")
+        item = Item(name="item1", attributes=[attr])
         order.items.append(item)
-        order.items.remove(item) # item is an orphan, but attr is not so flush() tries to save attr
-        try:
-            s.flush()
-            assert False
-        except exceptions.FlushError, e:
-            print e
-            assert True
+        order.items.remove(item)
+
+        assert item not in s
+        assert attr not in s
+
+        s.flush()
+        assert orders.count().scalar() == 1
+        assert items.count().scalar() == 0
+        assert attributes.count().scalar() == 0
+
+class UnsavedOrphansTest3(ORMTest):
+    """test not expunging double parents"""
+
+    def define_tables(self, meta):
+        global sales_reps, accounts, customers
+        sales_reps = Table('sales_reps', meta,
+            Column('sales_rep_id', Integer, Sequence('sales_rep_id_seq'), primary_key = True),
+            Column('name', String(50)),
+        )
+        accounts = Table('accounts', meta,
+            Column('account_id', Integer, Sequence('account_id_seq'), primary_key = True),
+            Column('balance', Integer),
+        )
+        customers = Table('customers', meta,
+            Column('customer_id', Integer, Sequence('customer_id_seq'), primary_key = True),
+            Column('name', String(50)),
+            Column('sales_rep_id', Integer, ForeignKey('sales_reps.sales_rep_id')),
+            Column('account_id', Integer, ForeignKey('accounts.account_id')),
+        )
+
+    def test_double_parent_expunge(self):
+        """test that a pending item with two delete-orphan parents is not expunged until it is removed from both parents."""
+        class Customer(fixtures.Base):
+            pass
+        class Account(fixtures.Base):
+            pass
+        class SalesRep(fixtures.Base):
+            pass
+
+        mapper(Customer, customers)
+        mapper(Account, accounts, properties=dict(
+            customers=relation(Customer, cascade="all,delete-orphan", backref="account")
+        ))
+        mapper(SalesRep, sales_reps, properties=dict(
+            customers=relation(Customer, cascade="all,delete-orphan", backref="sales_rep")
+        ))
+        s = create_session()
-        assert item.id is None
-        assert attr.id is None
+        a = Account(balance=0)
+        sr = SalesRep(name="John")
+        [s.save(x) for x in [a,sr]]
+        s.flush()
+
+        c = Customer(name="Jane")
+
+        a.customers.append(c)
+        sr.customers.append(c)
+        assert c in s
+
+        a.customers.remove(c)
+        assert c in s, "Should not expunge customer yet, still has one parent"
-class DoubleParentOrphanTest(AssertMixin):
+        sr.customers.remove(c)
+        assert c not in s, "Should expunge customer when both parents
are gone" + +class DoubleParentOrphanTest(ORMTest): """test orphan detection for an entity with two parent relations""" - - def setUpAll(self): - global metadata, address_table, businesses, homes - metadata = MetaData(testbase.db) + + def define_tables(self, metadata): + global address_table, businesses, homes address_table = Table('addresses', metadata, Column('address_id', Integer, primary_key=True), Column('street', String(30)), ) homes = Table('homes', metadata, - Column('home_id', Integer, primary_key=True), + Column('home_id', Integer, primary_key=True, key="id"), Column('description', String(30)), Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False), ) @@ -479,38 +748,42 @@ class DoubleParentOrphanTest(AssertMixin): Column('description', String(30), key="description"), Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False), ) - metadata.create_all() - def tearDown(self): - clear_mappers() - def tearDownAll(self): - metadata.drop_all() + def test_non_orphan(self): """test that an entity can have two parent delete-orphan cascades, and persists normally.""" + + class Address(fixtures.Base): + pass + class Home(fixtures.Base): + pass + class Business(fixtures.Base): + pass - class Address(object):pass - class Home(object):pass - class Business(object):pass mapper(Address, address_table) mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan")}) mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan")}) session = create_session() - a1 = Address() - a2 = Address() - h1 = Home() - b1 = Business() - h1.address = a1 - b1.address = a2 + h1 = Home(description='home1', address=Address(street='address1')) + b1 = Business(description='business1', address=Address(street='address2')) [session.save(x) for x in [h1,b1]] session.flush() + session.clear() + self.assertEquals(session.query(Home).get(h1.id), Home(description='home1', address=Address(street='address1'))) + self.assertEquals(session.query(Business).get(b1.id), Business(description='business1', address=Address(street='address2'))) + def test_orphan(self): """test that an entity can have two parent delete-orphan cascades, and is detected as an orphan when saved without a parent.""" - class Address(object):pass - class Home(object):pass - class Business(object):pass + class Address(fixtures.Base): + pass + class Home(fixtures.Base): + pass + class Business(fixtures.Base): + pass + mapper(Address, address_table) mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan")}) mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan")}) @@ -524,6 +797,47 @@ class DoubleParentOrphanTest(AssertMixin): except exceptions.FlushError, e: assert True +class CollectionAssignmentOrphanTest(ORMTest): + def define_tables(self, metadata): + global table_a, table_b + + table_a = Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30))) + table_b = Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30)), + Column('a_id', Integer, ForeignKey('a.id'))) + + def test_basic(self): + class A(fixtures.Base): + pass + class B(fixtures.Base): + pass + + mapper(A, table_a, properties={ + 'bs':relation(B, cascade="all, delete-orphan") + }) + mapper(B, table_b) + + a1 = A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]) + + sess = create_session() + sess.save(a1) + sess.flush() + + sess.clear() + + 
self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) + + a1 = sess.query(A).get(a1.id) + assert not class_mapper(B)._is_orphan(a1.bs[0]) + a1.bs[0].foo='b2modified' + a1.bs[1].foo='b3modified' + sess.flush() + + sess.clear() + self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/collection.py b/test/orm/collection.py index 1f4f649281..711dc730ba 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -1,4 +1,6 @@ -import testbase +import testenv; testenv.configure_for_tests() +import sys +from operator import and_ from sqlalchemy import * import sqlalchemy.exceptions as exceptions from sqlalchemy.orm import create_session, mapper, relation, \ @@ -6,9 +8,14 @@ from sqlalchemy.orm import create_session, mapper, relation, \ import sqlalchemy.orm.collections as collections from sqlalchemy.orm.collections import collection from sqlalchemy import util -from operator import and_ from testlib import * +try: + py_set = __builtins__.set +except AttributeError: + import sets + py_set = sets.Set + class Canary(interfaces.AttributeExtension): def __init__(self): self.data = set() @@ -35,7 +42,7 @@ class Entity(object): def __repr__(self): return str((id(self), self.a, self.b, self.c)) -manager = attributes.AttributeManager() +attributes.register_class(Entity) _id = 1 def entity_maker(): @@ -48,15 +55,16 @@ def dictable_entity(a=None, b=None, c=None): return Entity(a or str(_id), b or 'value %s' % _id, c) -class CollectionsTest(PersistTest): +class CollectionsTest(TestBase): def _test_adapter(self, typecallable, creator=entity_maker, to_set=None): class Foo(object): pass canary = Canary() - manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=typecallable) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable, useobject=True) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -73,12 +81,12 @@ class CollectionsTest(PersistTest): adapter.append_with_event(e1) assert_eq() - + adapter.append_without_event(e2) assert_ne() canary.data.add(e2) assert_eq() - + adapter.remove_without_event(e2) assert_ne() canary.data.remove(e2) @@ -90,10 +98,11 @@ class CollectionsTest(PersistTest): def _test_list(self, typecallable, creator=entity_maker): class Foo(object): pass - + canary = Canary() - manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=typecallable) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable, useobject=True) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -104,7 +113,7 @@ class CollectionsTest(PersistTest): self.assert_(set(direct) == canary.data) self.assert_(set(adapter) == canary.data) self.assert_(direct == control) - + # assume append() is available for list tests e = creator() direct.append(e) @@ -120,14 +129,14 @@ class CollectionsTest(PersistTest): e = creator() direct.append(e) control.append(e) - + e = creator() direct[0] = e control[0] = e assert_eq() if reduce(and_, [hasattr(direct, a) for a in - ('__delitem', 'insert', '__len__')], True): + ('__delitem__', 'insert', '__len__')], True): values = [creator(), creator(), creator(), creator()] direct[slice(0,1)] = values control[slice(0,1)] = values @@ -172,7 +181,7 @@ class CollectionsTest(PersistTest): e = 
creator() direct.append(e) control.append(e) - + direct.remove(e) control.remove(e) assert_eq() @@ -187,7 +196,22 @@ class CollectionsTest(PersistTest): direct[0:] = values control[0:] = values assert_eq() - + + values = [creator()] + direct[:1] = values + control[:1] = values + assert_eq() + + values = [creator()] + direct[-1::2] = values + control[-1::2] = values + assert_eq() + + values = [creator()] * len(direct[1::2]) + direct[1::2] = values + control[1::2] = values + assert_eq() + if hasattr(direct, '__delslice__'): for i in range(1, 4): e = creator() @@ -195,7 +219,7 @@ class CollectionsTest(PersistTest): control.append(e) del direct[-1:] - del control[-1:] + del control[-1:] assert_eq() del direct[1:2] @@ -213,13 +237,39 @@ class CollectionsTest(PersistTest): control.extend(values) assert_eq() + if hasattr(direct, '__iadd__'): + values = [creator(), creator(), creator()] + + direct += values + control += values + assert_eq() + + direct += [] + control += [] + assert_eq() + + values = [creator(), creator()] + obj.attr += values + control += values + assert_eq() + + if hasattr(direct, '__imul__'): + direct *= 2 + control *= 2 + assert_eq() + + obj.attr *= 2 + control *= 2 + assert_eq() + def _test_list_bulk(self, typecallable, creator=entity_maker): class Foo(object): pass canary = Canary() - manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=typecallable) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable, useobject=True) obj = Foo() direct = obj.attr @@ -238,7 +288,7 @@ class CollectionsTest(PersistTest): self.assert_(set(obj.attr) == set([e2])) self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) - + e3 = creator() real_list = [e3] obj.attr = real_list @@ -246,15 +296,30 @@ class CollectionsTest(PersistTest): self.assert_(set(obj.attr) == set([e3])) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) - + e4 = creator() try: obj.attr = set([e4]) self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(e4 not in canary.data) self.assert_(e3 in canary.data) + e5 = creator() + e6 = creator() + e7 = creator() + obj.attr = [e5, e6, e7] + self.assert_(e5 in canary.added) + self.assert_(e6 in canary.added) + self.assert_(e7 in canary.added) + + obj.attr = [e6, e7] + self.assert_(e5 in canary.removed) + self.assert_(e6 in canary.added) + self.assert_(e7 in canary.added) + self.assert_(e6 not in canary.removed) + self.assert_(e7 not in canary.removed) + def test_list(self): self._test_adapter(list) self._test_list(list) @@ -288,7 +353,7 @@ class CollectionsTest(PersistTest): return self.data == other def __repr__(self): return 'ListLike(%s)' % repr(self.data) - + self._test_adapter(ListLike) self._test_list(ListLike) self._test_list_bulk(ListLike) @@ -315,7 +380,7 @@ class CollectionsTest(PersistTest): return self.data == other def __repr__(self): return 'ListIsh(%s)' % repr(self.data) - + self._test_adapter(ListIsh) self._test_list(ListIsh) self._test_list_bulk(ListIsh) @@ -326,8 +391,9 @@ class CollectionsTest(PersistTest): pass canary = Canary() - manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=typecallable) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable, useobject=True) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -348,10 +414,13 @@ class CollectionsTest(PersistTest): for item in 
list(direct): direct.remove(item) control.clear() - - # assume add() is available for list tests + addall(creator()) + e = creator() + addall(e) + addall(e) + if hasattr(direct, 'pop'): direct.pop() control.pop() @@ -386,17 +455,46 @@ class CollectionsTest(PersistTest): direct.discard(e) self.assert_(e not in canary.removed) assert_eq() - + if hasattr(direct, 'update'): + zap() e = creator() addall(e) - + values = set([e, creator(), creator()]) direct.update(values) control.update(values) assert_eq() + if hasattr(direct, '__ior__'): + zap() + e = creator() + addall(e) + + values = set([e, creator(), creator()]) + + direct |= values + control |= values + assert_eq() + + # cover self-assignment short-circuit + values = set([e, creator(), creator()]) + obj.attr |= values + control |= values + assert_eq() + + values = frozenset([e, creator()]) + obj.attr |= values + control |= values + assert_eq() + + try: + direct |= [e, creator()] + assert False + except TypeError: + assert True + if hasattr(direct, 'clear'): addall(creator(), creator()) direct.clear() @@ -405,6 +503,7 @@ class CollectionsTest(PersistTest): if hasattr(direct, 'difference_update'): zap() + e = creator() addall(creator(), creator()) values = set([creator()]) @@ -416,6 +515,36 @@ class CollectionsTest(PersistTest): control.difference_update(values) assert_eq() + if hasattr(direct, '__isub__'): + zap() + e = creator() + addall(creator(), creator()) + values = set([creator()]) + + direct -= values + control -= values + assert_eq() + values.update(set([e, creator()])) + direct -= values + control -= values + assert_eq() + + values = set([creator()]) + obj.attr -= values + control -= values + assert_eq() + + values = frozenset([creator()]) + obj.attr -= values + control -= values + assert_eq() + + try: + direct -= [e, creator()] + assert False + except TypeError: + assert True + if hasattr(direct, 'intersection_update'): zap() e = creator() @@ -431,6 +560,32 @@ class CollectionsTest(PersistTest): control.intersection_update(values) assert_eq() + if hasattr(direct, '__iand__'): + zap() + e = creator() + addall(e, creator(), creator()) + values = set(control) + + direct &= values + control &= values + assert_eq() + + values.update(set([e, creator()])) + direct &= values + control &= values + assert_eq() + + values.update(set([creator()])) + obj.attr &= values + control &= values + assert_eq() + + try: + direct &= [e, creator()] + assert False + except TypeError: + assert True + if hasattr(direct, 'symmetric_difference_update'): zap() e = creator() @@ -453,13 +608,47 @@ class CollectionsTest(PersistTest): control.symmetric_difference_update(values) assert_eq() + if hasattr(direct, '__ixor__'): + zap() + e = creator() + addall(e, creator(), creator()) + + values = set([e, creator()]) + direct ^= values + control ^= values + assert_eq() + + e = creator() + addall(e) + values = set([e]) + direct ^= values + control ^= values + assert_eq() + + values = set() + direct ^= values + control ^= values + assert_eq() + + values = set([creator()]) + obj.attr ^= values + control ^= values + assert_eq() + + try: + direct ^= [e, creator()] + assert False + except TypeError: + assert True + def _test_set_bulk(self, typecallable, creator=entity_maker): class Foo(object): pass canary = Canary() - manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=typecallable) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable, useobject=True) obj = Foo() direct = 
obj.attr @@ -478,7 +667,7 @@ class CollectionsTest(PersistTest): self.assert_(obj.attr == set([e2])) self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) - + e3 = creator() real_set = set([e3]) obj.attr = real_set @@ -486,12 +675,12 @@ class CollectionsTest(PersistTest): self.assert_(obj.attr == set([e3])) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) - + e4 = creator() try: obj.attr = [e4] self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(e4 not in canary.data) self.assert_(e3 in canary.data) @@ -534,7 +723,7 @@ class CollectionsTest(PersistTest): def test_set_emulates(self): class SetIsh(object): - __emulates__ = set + __emulates__ = py_set def __init__(self): self.data = set() def add(self, item): @@ -562,8 +751,9 @@ class CollectionsTest(PersistTest): pass canary = Canary() - manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=typecallable) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable, useobject=True) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -584,7 +774,7 @@ class CollectionsTest(PersistTest): for item in list(adapter): direct.remove(item) control.clear() - + # assume an 'set' method is available for tests addall(creator()) @@ -619,7 +809,7 @@ class CollectionsTest(PersistTest): direct.clear() control.clear() assert_eq() - + direct.clear() control.clear() assert_eq() @@ -642,7 +832,7 @@ class CollectionsTest(PersistTest): zap() e = creator() addall(e) - + direct.popitem() control.popitem() assert_eq() @@ -669,18 +859,20 @@ class CollectionsTest(PersistTest): control.update(d) assert_eq() - kw = dict([(ee.a, ee) for ee in [e, creator()]]) - direct.update(**kw) - control.update(**kw) - assert_eq() + if sys.version_info >= (2, 4): + kw = dict([(ee.a, ee) for ee in [e, creator()]]) + direct.update(**kw) + control.update(**kw) + assert_eq() def _test_dict_bulk(self, typecallable, creator=dictable_entity): class Foo(object): pass canary = Canary() - manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=typecallable) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable, useobject=True) obj = Foo() direct = obj.attr @@ -700,23 +892,42 @@ class CollectionsTest(PersistTest): self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) + + # key validity on bulk assignment is a basic feature of MappedCollection + # but is not present in basic, @converter-less dict collections. 
e3 = creator() - real_dict = dict(keyignored1=e3) - obj.attr = real_dict - self.assert_(obj.attr is not real_dict) - self.assert_('keyignored1' not in obj.attr) - self.assert_(set(collections.collection_adapter(obj.attr)) == set([e3])) - self.assert_(e2 in canary.removed) - self.assert_(e3 in canary.added) + if isinstance(obj.attr, collections.MappedCollection): + real_dict = dict(badkey=e3) + try: + obj.attr = real_dict + self.assert_(False) + except TypeError: + pass + self.assert_(obj.attr is not real_dict) + self.assert_('badkey' not in obj.attr) + self.assertEquals(set(collections.collection_adapter(obj.attr)), + set([e2])) + self.assert_(e3 not in canary.added) + else: + real_dict = dict(keyignored1=e3) + obj.attr = real_dict + self.assert_(obj.attr is not real_dict) + self.assert_('keyignored1' not in obj.attr) + self.assertEquals(set(collections.collection_adapter(obj.attr)), + set([e3])) + self.assert_(e2 in canary.removed) + self.assert_(e3 in canary.added) + + obj.attr = typecallable() + self.assertEquals(list(collections.collection_adapter(obj.attr)), []) e4 = creator() try: obj.attr = [e4] self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(e4 not in canary.data) - self.assert_(e3 in canary.data) - + def test_dict(self): try: self._test_adapter(dict, dictable_entity, @@ -851,10 +1062,11 @@ class CollectionsTest(PersistTest): def _test_object(self, typecallable, creator=entity_maker): class Foo(object): pass - + canary = Canary() - manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=typecallable) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable, useobject=True) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -876,7 +1088,7 @@ class CollectionsTest(PersistTest): direct.zark(e) control.remove(e) assert_eq() - + e = creator() direct.maybe_zark(e) control.discard(e) @@ -948,13 +1160,116 @@ class CollectionsTest(PersistTest): self.assert_(getattr(MyCollection2, '_sa_instrumented') == id(MyCollection2)) + def test_recipes(self): + class Custom(object): + def __init__(self): + self.data = [] + @collection.appender + @collection.adds('entity') + def put(self, entity): + self.data.append(entity) + + @collection.remover + @collection.removes(1) + def remove(self, entity): + self.data.remove(entity) + + @collection.adds(1) + def push(self, *args): + self.data.append(args[0]) + + @collection.removes('entity') + def yank(self, entity, arg): + self.data.remove(entity) + + @collection.replaces(2) + def replace(self, arg, entity, **kw): + self.data.insert(0, entity) + return self.data.pop() + + @collection.removes_return() + def pop(self, key): + return self.data.pop() + + @collection.iterator + def __iter__(self): + return iter(self.data) + + class Foo(object): + pass + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=Custom, useobject=True) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = list() + def assert_eq(): + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(list(direct) == control) + creator = entity_maker + + e1 = creator() + direct.put(e1) + control.append(e1) + assert_eq() + + e2 = creator() + direct.put(entity=e2) + control.append(e2) + assert_eq() + + direct.remove(e2) + control.remove(e2) + assert_eq() + + direct.remove(entity=e1) + 
control.remove(e1) + assert_eq() + + e3 = creator() + direct.push(e3) + control.append(e3) + assert_eq() + + direct.yank(e3, 'blah') + control.remove(e3) + assert_eq() + + e4, e5, e6, e7 = creator(), creator(), creator(), creator() + direct.put(e4) + direct.put(e5) + control.append(e4) + control.append(e5) + + dr1 = direct.replace('foo', e6, bar='baz') + control.insert(0, e6) + cr1 = control.pop() + assert_eq() + self.assert_(dr1 is cr1) + + dr2 = direct.replace(arg=1, entity=e7) + control.insert(0, e7) + cr2 = control.pop() + assert_eq() + self.assert_(dr2 is cr2) + + dr3 = direct.pop('blah') + cr3 = control.pop() + assert_eq() + self.assert_(dr3 is cr3) + def test_lifecycle(self): class Foo(object): pass canary = Canary() creator = entity_maker - manager.register_attribute(Foo, 'attr', True, extension=canary) + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', True, extension=canary, useobject=True) obj = Foo() col1 = obj.attr @@ -976,24 +1291,24 @@ class CollectionsTest(PersistTest): col1.append(e3) self.assert_(e3 not in canary.data) self.assert_(collections.collection_adapter(col1) is None) - + obj.attr[0] = e3 self.assert_(e3 in canary.data) class DictHelpersTest(ORMTest): def define_tables(self, metadata): global parents, children, Parent, Child - + parents = Table('parents', metadata, Column('id', Integer, primary_key=True), - Column('label', String)) + Column('label', String(128))) children = Table('children', metadata, Column('id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('parents.id'), nullable=False), - Column('a', String), - Column('b', String), - Column('c', String)) + Column('a', String(128)), + Column('b', String(128)), + Column('c', String(128))) class Parent(object): def __init__(self, label=None): @@ -1010,7 +1325,7 @@ class DictHelpersTest(ORMTest): 'children': relation(Child, collection_class=collection_class, cascade="all, delete-orphan") }) - + p = Parent() p.children['foo'] = Child('foo', 'value') p.children['bar'] = Child('bar', 'value') @@ -1027,16 +1342,15 @@ class DictHelpersTest(ORMTest): collections.collection_adapter(p.children).append_with_event( Child('foo', 'newvalue')) - - session.save(p) + session.flush() session.clear() - + p = session.query(Parent).get(pid) - + self.assert_(set(p.children.keys()) == set(['foo', 'bar'])) self.assert_(p.children['foo'].id != cid) - + self.assert_(len(list(collections.collection_adapter(p.children))) == 2) session.flush() session.clear() @@ -1046,7 +1360,7 @@ class DictHelpersTest(ORMTest): collections.collection_adapter(p.children).remove_with_event( p.children['foo']) - + self.assert_(len(list(collections.collection_adapter(p.children))) == 1) session.flush() session.clear() @@ -1061,7 +1375,7 @@ class DictHelpersTest(ORMTest): p = session.query(Parent).get(pid) self.assert_(len(list(collections.collection_adapter(p.children))) == 0) - + def _test_composite_mapped(self, collection_class): mapper(Child, children) @@ -1069,7 +1383,7 @@ class DictHelpersTest(ORMTest): 'children': relation(Child, collection_class=collection_class, cascade="all, delete-orphan") }) - + p = Parent() p.children[('foo', '1')] = Child('foo', '1', 'value 1') p.children[('foo', '2')] = Child('foo', '2', 'value 2') @@ -1079,7 +1393,7 @@ class DictHelpersTest(ORMTest): session.flush() pid = p.id session.clear() - + p = session.query(Parent).get(pid) self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) @@ -1087,18 +1401,17 @@ class DictHelpersTest(ORMTest): 
collections.collection_adapter(p.children).append_with_event( Child('foo', '1', 'newvalue')) - - session.save(p) + session.flush() session.clear() - + p = session.query(Parent).get(pid) - + self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) self.assert_(p.children[('foo', '1')].id != cid) - + self.assert_(len(list(collections.collection_adapter(p.children))) == 2) - + def test_mapped_collection(self): collection_class = collections.mapped_collection(lambda c: c.a) self._test_scalar_mapped(collection_class) @@ -1136,5 +1449,311 @@ class DictHelpersTest(ORMTest): collection_class = lambda: Ordered2(lambda v: (v.a, v.b)) self._test_composite_mapped(collection_class) +# TODO: are these tests redundant vs. the above tests ? +# remove if so +class CustomCollectionsTest(ORMTest): + def define_tables(self, metadata): + global sometable, someothertable + sometable = Table('sometable', metadata, + Column('col1',Integer, primary_key=True), + Column('data', String(30))) + someothertable = Table('someothertable', metadata, + Column('col1', Integer, primary_key=True), + Column('scol1', Integer, ForeignKey(sometable.c.col1)), + Column('data', String(20)) + ) + def test_basic(self): + class MyList(list): + pass + class Foo(object): + pass + class Bar(object): + pass + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, collection_class=MyList) + }) + mapper(Bar, someothertable) + f = Foo() + assert isinstance(f.bars, MyList) + + def test_lazyload(self): + """test that a 'set' can be used as a collection and can lazyload.""" + class Foo(object): + pass + class Bar(object): + pass + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, collection_class=set) + }) + mapper(Bar, someothertable) + f = Foo() + f.bars.add(Bar()) + f.bars.add(Bar()) + sess = create_session() + sess.save(f) + sess.flush() + sess.clear() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + f.bars.clear() + + def test_dict(self): + """test that a 'dict' can be used as a collection and can lazyload.""" + + class Foo(object): + pass + class Bar(object): + pass + class AppenderDict(dict): + @collection.appender + def set(self, item): + self[id(item)] = item + @collection.remover + def remove(self, item): + if id(item) in self: + del self[id(item)] + + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, collection_class=AppenderDict) + }) + mapper(Bar, someothertable) + f = Foo() + f.bars.set(Bar()) + f.bars.set(Bar()) + sess = create_session() + sess.save(f) + sess.flush() + sess.clear() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + f.bars.clear() + + def test_dict_wrapper(self): + """test that the supplied 'dict' wrapper can be used as a collection and can lazyload.""" + + class Foo(object): + pass + class Bar(object): + def __init__(self, data): self.data = data + + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, + collection_class=collections.column_mapped_collection(someothertable.c.data)) + }) + mapper(Bar, someothertable) + + f = Foo() + col = collections.collection_adapter(f.bars) + col.append_with_event(Bar('a')) + col.append_with_event(Bar('b')) + sess = create_session() + sess.save(f) + sess.flush() + sess.clear() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + + existing = set([id(b) for b in f.bars.values()]) + + col = collections.collection_adapter(f.bars) + col.append_with_event(Bar('b')) + f.bars['a'] = Bar('a') + sess.flush() + sess.clear() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + + 
replaced = set([id(b) for b in f.bars.values()]) + self.assert_(existing != replaced) + + def test_list(self): + class Parent(object): + pass + class Child(object): + pass + + mapper(Parent, sometable, properties={ + 'children':relation(Child, collection_class=list) + }) + mapper(Child, someothertable) + + control = list() + p = Parent() + + o = Child() + control.append(o) + p.children.append(o) + assert control == p.children + assert control == list(p.children) + + o = [Child(), Child(), Child(), Child()] + control.extend(o) + p.children.extend(o) + assert control == p.children + assert control == list(p.children) + + assert control[0] == p.children[0] + assert control[-1] == p.children[-1] + assert control[1:3] == p.children[1:3] + + del control[1] + del p.children[1] + assert control == p.children + assert control == list(p.children) + + o = [Child()] + control[1:3] = o + p.children[1:3] = o + assert control == p.children + assert control == list(p.children) + + o = [Child(), Child(), Child(), Child()] + control[1:3] = o + p.children[1:3] = o + assert control == p.children + assert control == list(p.children) + + o = [Child(), Child(), Child(), Child()] + control[-1:-2] = o + p.children[-1:-2] = o + assert control == p.children + assert control == list(p.children) + + o = [Child(), Child(), Child(), Child()] + control[4:] = o + p.children[4:] = o + assert control == p.children + assert control == list(p.children) + + o = Child() + control.insert(0, o) + p.children.insert(0, o) + assert control == p.children + assert control == list(p.children) + + o = Child() + control.insert(3, o) + p.children.insert(3, o) + assert control == p.children + assert control == list(p.children) + + o = Child() + control.insert(999, o) + p.children.insert(999, o) + assert control == p.children + assert control == list(p.children) + + del control[0:1] + del p.children[0:1] + assert control == p.children + assert control == list(p.children) + + del control[1:1] + del p.children[1:1] + assert control == p.children + assert control == list(p.children) + + del control[1:3] + del p.children[1:3] + assert control == p.children + assert control == list(p.children) + + del control[7:] + del p.children[7:] + assert control == p.children + assert control == list(p.children) + + assert control.pop() == p.children.pop() + assert control == p.children + assert control == list(p.children) + + assert control.pop(0) == p.children.pop(0) + assert control == p.children + assert control == list(p.children) + + assert control.pop(2) == p.children.pop(2) + assert control == p.children + assert control == list(p.children) + + o = Child() + control.insert(2, o) + p.children.insert(2, o) + assert control == p.children + assert control == list(p.children) + + control.remove(o) + p.children.remove(o) + assert control == p.children + assert control == list(p.children) + + def test_custom(self): + class Parent(object): + pass + class Child(object): + pass + + class MyCollection(object): + def __init__(self): + self.data = [] + @collection.appender + def append(self, value): + self.data.append(value) + @collection.remover + def remove(self, value): + self.data.remove(value) + @collection.iterator + def __iter__(self): + return iter(self.data) + + mapper(Parent, sometable, properties={ + 'children':relation(Child, collection_class=MyCollection) + }) + mapper(Child, someothertable) + + control = list() + p1 = Parent() + + o = Child() + control.append(o) + p1.children.append(o) + assert control == list(p1.children) + + o = Child() + 
control.append(o) + p1.children.append(o) + assert control == list(p1.children) + + o = Child() + control.append(o) + p1.children.append(o) + assert control == list(p1.children) + + sess = create_session() + sess.save(p1) + sess.flush() + sess.clear() + + p2 = sess.query(Parent).get(p1.col1) + o = list(p2.children) + assert len(o) == 3 + + +class InstrumentationTest(TestBase): + + def test_uncooperative_descriptor_in_sweep(self): + class DoNotTouch(object): + def __get__(self, obj, owner): + raise AttributeError + + class Touchy(list): + no_touch = DoNotTouch() + + assert 'no_touch' in Touchy.__dict__ + assert not hasattr(Touchy, 'no_touch') + assert 'no_touch' in dir(Touchy) + + instrumented = collections._instrument_class(Touchy) + assert True + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/compile.py b/test/orm/compile.py index 23f04db856..31b6860623 100644 --- a/test/orm/compile.py +++ b/test/orm/compile.py @@ -1,19 +1,20 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions from sqlalchemy.orm import * from testlib import * -class CompileTest(AssertMixin): +class CompileTest(TestBase, AssertsExecutionResults): """test various mapper compilation scenarios""" - def tearDownAll(self): + def tearDown(self): clear_mappers() - + def testone(self): global metadata, order, employee, product, tax, orderproduct - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) - order = Table('orders', metadata, + order = Table('orders', metadata, Column('id', Integer, primary_key=True), Column('employee_id', Integer, ForeignKey('employees.id'), nullable=False), Column('type', Unicode(16))) @@ -46,9 +47,9 @@ class CompileTest(AssertMixin): order_join = order.select().alias('pjoin') - order_mapper = mapper(Order, order, - select_table=order_join, - polymorphic_on=order_join.c.type, + order_mapper = mapper(Order, order, + select_table=order_join, + polymorphic_on=order_join.c.type, polymorphic_identity='order', properties={ 'orderproducts': relation(OrderProduct, lazy=True, backref='order')} @@ -64,7 +65,7 @@ class CompileTest(AssertMixin): 'orders': relation(Order, lazy=True, backref='employee')}) mapper(OrderProduct, orderproduct) - + # this requires that the compilation of order_mapper's "surrogate mapper" occur after # the initial setup of MapperProperty objects on the mapper. 
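# (clarifying note, assumption on terminology: "surrogate mapper" above appears to
# refer to the additional mapper that 0.4 builds internally against the 'pjoin'
# alias passed as select_table, which is what polymorphic loads actually select
# from.  The compile() call below simply checks that compiling a dependent mapper
# first does not force that construction before the MapperProperty objects exist.)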
class_mapper(Product).compile() @@ -72,9 +73,9 @@ class CompileTest(AssertMixin): def testtwo(self): """test that conflicting backrefs raises an exception""" global metadata, order, employee, product, tax, orderproduct - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) - order = Table('orders', metadata, + order = Table('orders', metadata, Column('id', Integer, primary_key=True), Column('type', Unicode(16))) @@ -99,9 +100,9 @@ class CompileTest(AssertMixin): order_join = order.select().alias('pjoin') - order_mapper = mapper(Order, order, - select_table=order_join, - polymorphic_on=order_join.c.type, + order_mapper = mapper(Order, order, + select_table=order_join, + polymorphic_on=order_join.c.type, polymorphic_identity='order', properties={ 'orderproducts': relation(OrderProduct, lazy=True, backref='product')} @@ -118,15 +119,15 @@ class CompileTest(AssertMixin): class_mapper(Product).compile() assert False except exceptions.ArgumentError, e: - assert str(e).index("Backrefs do not match") > -1 + assert str(e).index("Error creating backref ") > -1 def testthree(self): - metadata = MetaData(testbase.db) - node_table = Table("node", metadata, + metadata = MetaData(testing.db) + node_table = Table("node", metadata, Column('node_id', Integer, primary_key=True), Column('name_index', Integer, nullable=True), ) - node_name_table = Table("node_name", metadata, + node_name_table = Table("node_name", metadata, Column('node_name_id', Integer, primary_key=True), Column('node_id', Integer, ForeignKey('node.node_id')), Column('host_id', Integer, ForeignKey('host.host_id')), @@ -143,7 +144,7 @@ class CompileTest(AssertMixin): class Node(object):pass class NodeName(object):pass class Host(object):pass - + node_mapper = mapper(Node, node_table) host_mapper = mapper(Host, host_table) node_name_mapper = mapper(NodeName, node_name_table, @@ -157,5 +158,27 @@ class CompileTest(AssertMixin): finally: metadata.drop_all() + def testfour(self): + meta = MetaData() + + a = Table('a', meta, Column('id', Integer, primary_key=True)) + b = Table('b', meta, Column('id', Integer, primary_key=True), Column('a_id', Integer, ForeignKey('a.id'))) + + class A(object):pass + class B(object):pass + + mapper(A, a, properties={ + 'b':relation(B, backref='a') + }) + mapper(B, b, properties={ + 'a':relation(A, backref='b') + }) + + try: + compile_mappers() + assert False + except exceptions.ArgumentError, e: + assert str(e).index("Error creating backref") > -1 + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/orm/cycles.py b/test/orm/cycles.py index ce3065f777..f956a4529b 100644 --- a/test/orm/cycles.py +++ b/test/orm/cycles.py @@ -1,14 +1,15 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * from testlib.tables import * -"""test cyclical mapper relationships. Many of the assertions are provided -via running with postgres, which is strict about foreign keys. +""" +Tests cyclical mapper relationships. -we might want to try an automated generate of much of this, all combos of T1<->T2, with -o2m or m2o between them, and a third T3 with o2m/m2o to one/both T1/T2. +We might want to try an automated generate of much of this, all combos of +T1<->T2, with o2m or m2o between them, and a third T3 with o2m/m2o to one/both +T1/T2. 
""" @@ -18,13 +19,13 @@ class Tester(object): print repr(self) + " (%d)" % (id(self)) def __repr__(self): return "%s(%s)" % (self.__class__.__name__, repr(self.data)) - -class SelfReferentialTest(AssertMixin): + +class SelfReferentialTest(TestBase, AssertsExecutionResults): """tests a self-referential mapper, with an additional list of child objects.""" def setUpAll(self): global t1, t2, metadata - metadata = MetaData(testbase.db) - t1 = Table('t1', metadata, + metadata = MetaData(testing.db) + t1 = Table('t1', metadata, Column('c1', Integer, Sequence('t1c1_id_seq', optional=True), primary_key=True), Column('parent_c1', Integer, ForeignKey('t1.c1')), Column('data', String(20)) @@ -39,7 +40,7 @@ class SelfReferentialTest(AssertMixin): metadata.drop_all() def setUp(self): clear_mappers() - + def testsingle(self): class C1(Tester): pass @@ -54,13 +55,13 @@ class SelfReferentialTest(AssertMixin): sess.flush() sess.delete(a) sess.flush() - + def testmanytooneonly(self): """test that the circular dependency sort can assemble a many-to-one dependency processor when only the object on the "many" side is actually in the list of modified objects. this requires that the circular sort add the other side of the relation into the UOWTransaction so that the dependency operation can be tacked onto it. - + This also affects inheritance relationships since they rely upon circular sort as well. """ class C1(Tester): @@ -79,16 +80,16 @@ class SelfReferentialTest(AssertMixin): sess.save(c2) sess.flush() assert c2.parent_c1==c1.c1 - + def testcycle(self): class C1(Tester): pass class C2(Tester): pass - + m1 = mapper(C1, t1, properties = { 'c1s' : relation(C1, cascade="all"), - 'c2s' : relation(mapper(C2, t2), private=True) + 'c2s' : relation(mapper(C2, t2), cascade="all, delete-orphan") }) a = C1('head c1') @@ -101,15 +102,15 @@ class SelfReferentialTest(AssertMixin): sess = create_session( ) sess.save(a) sess.flush() - + sess.delete(a) sess.flush() -class SelfReferentialNoPKTest(AssertMixin): +class SelfReferentialNoPKTest(TestBase, AssertsExecutionResults): """test self-referential relationship that joins on a column other than the primary key column""" def setUpAll(self): global table, meta - meta = MetaData(testbase.db) + meta = MetaData(testing.db) table = Table('item', meta, Column('id', Integer, primary_key=True), Column('uuid', String(32), unique=True, nullable=False), @@ -132,7 +133,7 @@ class SelfReferentialNoPKTest(AssertMixin): s.save(t1) s.flush() s.clear() - t = s.query(TT).get_by(id=t1.id) + t = s.query(TT).filter_by(id=t1.id).one() assert t.children[0].parent_uuid == t1.uuid def testlazyclause(self): class TT(object): @@ -147,14 +148,14 @@ class SelfReferentialNoPKTest(AssertMixin): s.flush() s.clear() - t = s.query(TT).get_by(id=t2.id) + t = s.query(TT).filter_by(id=t2.id).one() assert t.uuid == t2.uuid assert t.parent.uuid == t1.uuid - -class InheritTestOne(AssertMixin): + +class InheritTestOne(TestBase, AssertsExecutionResults): def setUpAll(self): global parent, child1, child2, meta - meta = MetaData(testbase.db) + meta = MetaData(testing.db) parent = Table("parent", meta, Column("id", Integer, primary_key=True), Column("parent_data", String(50)), @@ -203,12 +204,12 @@ class InheritTestOne(AssertMixin): session.flush() session.clear() - c1 = session.query(Child1).get_by(child1_data="qwerty") + c1 = session.query(Child1).filter_by(child1_data="qwerty").one() c2 = Child2() c2.child1 = c1 c2.child2_data = "asdfgh" session.save(c2) - # the flush will fail if the UOW does not set up a many-to-one 
DP + # the flush will fail if the UOW does not set up a many-to-one DP # attached to a task corresponding to c1, since "child1_id" is not nullable session.flush() @@ -218,18 +219,18 @@ class InheritTestTwo(ORMTest): create duplicate entries in the final sort""" def define_tables(self, metadata): global a, b, c - a = Table('a', metadata, + a = Table('a', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), Column('cid', Integer, ForeignKey('c.id')), ) - b = Table('b', metadata, + b = Table('b', metadata, Column('id', Integer, ForeignKey("a.id"), primary_key=True), Column('data', String(30)), ) - c = Table('c', metadata, + c = Table('c', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), Column('aid', Integer, ForeignKey('a.id', use_alter=True, name="foo")), @@ -255,7 +256,7 @@ class InheritTestTwo(ORMTest): cobj = C() sess.save(cobj) sess.flush() - + class BiDirectionalManyToOneTest(ORMTest): def define_tables(self, metadata): @@ -276,12 +277,12 @@ class BiDirectionalManyToOneTest(ORMTest): Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), Column('t2id', Integer, ForeignKey('t2.id'), nullable=False), ) - + def test_reflush(self): class T1(object):pass class T2(object):pass class T3(object):pass - + mapper(T1, t1, properties={ 't2':relation(T2, primaryjoin=t1.c.t2id==t2.c.id) }) @@ -292,13 +293,13 @@ class BiDirectionalManyToOneTest(ORMTest): 't1':relation(T1), 't2':relation(T2) }) - + o1 = T1() o1.t2 = T2() sess = create_session() sess.save(o1) sess.flush() - + # the bug here is that the dependency sort comes up with T1/T2 in a cycle, but there # are no T1/T2 objects to be saved. therefore no "cyclical subtree" gets generated, # and one or the other of T1/T2 gets lost, and processors on T3 dont fire off. 
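# (clarifying summary, inferred from the mappings above rather than adding to them:
# T1 and T2 each carry a relation() to the other, so the topological dependency
# sort always sees a T1 <-> T2 cycle, while T3 has a plain many-to-one to each of
# them.  The flush below saves only a new T3 pointing at already-persistent T1/T2
# rows -- exactly the case where the "cyclical subtree" used to come out empty and
# T3's dependency processors never fired.)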
@@ -308,7 +309,7 @@ class BiDirectionalManyToOneTest(ORMTest): o3.t2 = o1.t2 sess.save(o3) sess.flush() - + def test_reflush_2(self): """a variant on test_reflush()""" @@ -345,19 +346,19 @@ class BiDirectionalManyToOneTest(ORMTest): o3b.t1 = o1a o3b.t2 = o2a sess.save(o3b) - + o3 = T3() o3.t1 = o1 o3.t2 = o1.t2 sess.save(o3) sess.flush() - -class BiDirectionalOneToManyTest(AssertMixin): + +class BiDirectionalOneToManyTest(TestBase, AssertsExecutionResults): """tests two mappers with a one-to-many relation to each other.""" def setUpAll(self): global t1, t2, metadata - metadata = MetaData(testbase.db) - t1 = Table('t1', metadata, + metadata = MetaData(testing.db) + t1 = Table('t1', metadata, Column('c1', Integer, Sequence('t1c1_id_seq', optional=True), primary_key=True), Column('c2', Integer, ForeignKey('t2.c1')) ) @@ -373,7 +374,7 @@ class BiDirectionalOneToManyTest(AssertMixin): def testcycle(self): class C1(object):pass class C2(object):pass - + m2 = mapper(C2, t2, properties={ 'c1s': relation(C1, primaryjoin=t2.c.c1==t1.c.c2, uselist=True) }) @@ -393,12 +394,12 @@ class BiDirectionalOneToManyTest(AssertMixin): [sess.save(x) for x in [a,b,c,d,e,f]] sess.flush() -class BiDirectionalOneToManyTest2(AssertMixin): +class BiDirectionalOneToManyTest2(TestBase, AssertsExecutionResults): """tests two mappers with a one-to-many relation to each other, with a second one-to-many on one of the mappers""" def setUpAll(self): global t1, t2, t3, metadata - metadata = MetaData(testbase.db) - t1 = Table('t1', metadata, + metadata = MetaData(testing.db) + t1 = Table('t1', metadata, Column('c1', Integer, Sequence('t1c1_id_seq', optional=True), primary_key=True), Column('c2', Integer, ForeignKey('t2.c1')), ) @@ -406,25 +407,25 @@ class BiDirectionalOneToManyTest2(AssertMixin): Column('c1', Integer, Sequence('t2c1_id_seq', optional=True), primary_key=True), Column('c2', Integer, ForeignKey('t1.c1', use_alter=True, name='t1c1_fq')), ) - t3 = Table('t1_data', metadata, + t3 = Table('t1_data', metadata, Column('c1', Integer, Sequence('t1dc1_id_seq', optional=True), primary_key=True), Column('t1id', Integer, ForeignKey('t1.c1')), Column('data', String(20))) metadata.create_all() - + def tearDown(self): clear_mappers() def tearDownAll(self): metadata.drop_all() - + def testcycle(self): class C1(object):pass class C2(object):pass class C1Data(object): def __init__(self, data=None): self.data = data - + m2 = mapper(C2, t2, properties={ 'c1s': relation(C1, primaryjoin=t2.c.c1==t1.c.c2, uselist=True) }) @@ -432,7 +433,7 @@ class BiDirectionalOneToManyTest2(AssertMixin): 'c2s' : relation(C2, primaryjoin=t1.c.c1==t2.c.c2, uselist=True), 'data' : relation(mapper(C1Data, t3)) }) - + a = C1() b = C2() c = C1() @@ -453,14 +454,14 @@ class BiDirectionalOneToManyTest2(AssertMixin): sess.delete(c) sess.flush() -class OneToManyManyToOneTest(AssertMixin): +class OneToManyManyToOneTest(TestBase, AssertsExecutionResults): """tests two mappers, one has a one-to-many on the other mapper, the other has a separate many-to-one relationship to the first. two tests will have a row for each item that is dependent on the other. 
without the "post_update" flag, such relationships raise an exception when dependencies are sorted.""" def setUpAll(self): global metadata - metadata = MetaData(testbase.db) - global person + metadata = MetaData(testing.db) + global person global ball ball = Table('ball', metadata, Column('id', Integer, Sequence('ball_id_seq', optional=True), primary_key=True), @@ -474,15 +475,15 @@ class OneToManyManyToOneTest(AssertMixin): ) metadata.create_all() - + def tearDownAll(self): metadata.drop_all() - + def tearDown(self): clear_mappers() def testcycle(self): - """this test has a peculiar aspect in that it doesnt create as many dependent + """this test has a peculiar aspect in that it doesnt create as many dependent relationships as the other tests, and revealed a small glitch in the circular dependency sorting.""" class Person(object): pass @@ -517,7 +518,7 @@ class OneToManyManyToOneTest(AssertMixin): Ball.mapper = mapper(Ball, ball) Person.mapper = mapper(Person, person, properties= dict( - balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id, post_update=False, private=True), + balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id, post_update=False, cascade="all, delete-orphan"), favorateBall = relation(Ball.mapper, primaryjoin=person.c.favorite_ball_id==ball.c.id, remote_side=person.c.favorite_ball_id, post_update=True), ) ) @@ -532,8 +533,8 @@ class OneToManyManyToOneTest(AssertMixin): sess = create_session() sess.save(b) sess.save(p) - - self.assert_sql(testbase.db, lambda: sess.flush(), [ + + self.assert_sql(testing.db, lambda: sess.flush(), [ ( "INSERT INTO person (favorite_ball_id, data) VALUES (:favorite_ball_id, :data)", {'favorite_ball_id': None, 'data':'some data'} @@ -558,7 +559,7 @@ class OneToManyManyToOneTest(AssertMixin): "UPDATE person SET favorite_ball_id=:favorite_ball_id WHERE person.id = :person_id", lambda ctx:{'favorite_ball_id':p.favorateBall.id,'person_id':p.id} ) - ], + ], with_sequences= [ ( "INSERT INTO person (id, favorite_ball_id, data) VALUES (:id, :favorite_ball_id, :data)", @@ -580,14 +581,14 @@ class OneToManyManyToOneTest(AssertMixin): "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)", lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id, 'data':'some data'} ), - # heres the post update + # heres the post update ( "UPDATE person SET favorite_ball_id=:favorite_ball_id WHERE person.id = :person_id", lambda ctx:{'favorite_ball_id':p.favorateBall.id,'person_id':p.id} ) ]) sess.delete(p) - self.assert_sql(testbase.db, lambda: sess.flush(), [ + self.assert_sql(testing.db, lambda: sess.flush(), [ # heres the post update (which is a pre-update with deletes) ( "UPDATE person SET favorite_ball_id=:favorite_ball_id WHERE person.id = :person_id", @@ -606,7 +607,7 @@ class OneToManyManyToOneTest(AssertMixin): ]) - + def testpostupdate_o2m(self): """tests a cycle between two rows, with a post_update on the one-to-many""" class Person(object): @@ -619,7 +620,7 @@ class OneToManyManyToOneTest(AssertMixin): Ball.mapper = mapper(Ball, ball) Person.mapper = mapper(Person, person, properties= dict( - balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id, private=True, post_update=True, backref='person'), + balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id, cascade="all, delete-orphan", post_update=True, backref='person'), favorateBall = 
relation(Ball.mapper, primaryjoin=person.c.favorite_ball_id==ball.c.id, remote_side=person.c.favorite_ball_id), ) ) @@ -637,7 +638,7 @@ class OneToManyManyToOneTest(AssertMixin): sess = create_session() [sess.save(x) for x in [b,p,b2,b3,b4]] - self.assert_sql(testbase.db, lambda: sess.flush(), [ + self.assert_sql(testing.db, lambda: sess.flush(), [ ( "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)", {'person_id':None, 'data':'some data'} @@ -716,7 +717,7 @@ class OneToManyManyToOneTest(AssertMixin): ]) sess.delete(p) - self.assert_sql(testbase.db, lambda: sess.flush(), [ + self.assert_sql(testing.db, lambda: sess.flush(), [ ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", lambda ctx:{'person_id': None, 'ball_id': b.id} @@ -743,11 +744,11 @@ class OneToManyManyToOneTest(AssertMixin): ) ]) -class SelfReferentialPostUpdateTest(AssertMixin): +class SelfReferentialPostUpdateTest(TestBase, AssertsExecutionResults): """test using post_update on a single self-referential mapper""" def setUpAll(self): global metadata, node_table - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) node_table = Table('node', metadata, Column('id', Integer, Sequence('nodeid_id_seq', optional=True), primary_key=True), Column('path', String(50), nullable=False), @@ -758,10 +759,10 @@ class SelfReferentialPostUpdateTest(AssertMixin): node_table.create() def tearDownAll(self): node_table.drop() - + def testbasic(self): """test that post_update only fires off when needed. - + this test case used to produce many superfluous update statements, particularly upon delete""" class Node(object): def __init__(self, path=''): @@ -795,11 +796,11 @@ class SelfReferentialPostUpdateTest(AssertMixin): session = create_session() def append_child(parent, child): - if len(parent.children): + if parent.children: parent.children[-1].next_sibling = child child.prev_sibling = parent.children[-1] parent.children.append(child) - + def remove_child(parent, child): child.parent = None node = child.next_sibling @@ -828,7 +829,7 @@ class SelfReferentialPostUpdateTest(AssertMixin): remove_child(root, cats) # pre-trigger lazy loader on 'cats' to make the test easier cats.children - self.assert_sql(testbase.db, lambda: session.flush(), [ + self.assert_sql(testing.db, lambda: session.flush(), [ ( "UPDATE node SET prev_sibling_id=:prev_sibling_id WHERE node.id = :node_id", lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id} @@ -847,22 +848,22 @@ class SelfReferentialPostUpdateTest(AssertMixin): ), ]) -class SelfReferentialPostUpdateTest2(AssertMixin): +class SelfReferentialPostUpdateTest2(TestBase, AssertsExecutionResults): def setUpAll(self): global metadata, a_table - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) a_table = Table("a", metadata, Column("id", Integer(), primary_key=True), - Column("fui", String()), + Column("fui", String(128)), Column("b", Integer(), ForeignKey("a.id")), ) a_table.create() def tearDownAll(self): a_table.drop() def testbasic(self): - """test that post_update remembers to be involved in update operations as well, + """test that post_update remembers to be involved in update operations as well, since it replaces the normal dependency processing completely [ticket:413]""" - class a(object): + class a(object): def __init__(self, fui): self.fui = fui @@ -882,12 +883,11 @@ class SelfReferentialPostUpdateTest2(AssertMixin): # to fire off anyway session.save(f2) session.flush() - + session.clear() f1 = session.query(a).get(f1.id) f2 = 
session.query(a).get(f2.id) assert f2.foo is f1 - -if __name__ == "__main__": - testbase.main() +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py new file mode 100644 index 0000000000..c38b278238 --- /dev/null +++ b/test/orm/dynamic.py @@ -0,0 +1,331 @@ +import testenv; testenv.configure_for_tests() +import operator +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import * +from testlib.fixtures import * + +from query import QueryTest + +class DynamicTest(FixtureTest): + keep_mappers = False + refresh_data = True + + def test_basic(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + q = sess.query(User) + + print q.filter(User.id==7).all() + u = q.filter(User.id==7).first() + print list(u.addresses) + assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all() + assert fixtures.user_address_result == q.all() + + def test_order_by(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u = sess.query(User).get(8) + self.assertEquals(list(u.addresses.order_by(desc(Address.email_address))), [Address(email_address=u'ed@wood.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@bettyboop.com')]) + + def test_configured_order_by(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), order_by=desc(Address.email_address)) + }) + sess = create_session() + u = sess.query(User).get(8) + self.assertEquals(list(u.addresses), [Address(email_address=u'ed@wood.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@bettyboop.com')]) + + def test_count(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u = sess.query(User).first() + assert u.addresses.count() == 1, u.addresses.count() + + def test_backref(self): + mapper(Address, addresses, properties={ + 'user':relation(User, backref=backref('addresses', lazy='dynamic')) + }) + mapper(User, users) + + sess = create_session() + ad = sess.query(Address).get(1) + def go(): + ad.user = None + self.assert_sql_count(testing.db, go, 1) + sess.flush() + u = sess.query(User).get(7) + assert ad not in u.addresses + + def test_no_count(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + q = sess.query(User) + + # dynamic collection cannot implement __len__() (at least one that returns a live database + # result), else additional count() queries are issued when evaluating in a list context + def go(): + assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all() + self.assert_sql_count(testing.db, go, 2) + + def test_m2m(self): + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy="dynamic", backref=backref('orders', lazy="dynamic")) + }) + mapper(Item, items) + + sess = create_session() + o1 = Order(id=15, description="order 10") + i1 = Item(id=10, description="item 8") + o1.items.append(i1) + sess.save(o1) + sess.flush() + + assert o1 in i1.orders.all() + assert i1 in o1.items.all() + + def test_transient_detached(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u1 = User() + 
u1.addresses.append(Address()) + assert u1.addresses.count() == 1 + assert u1.addresses[0] == Address() + +class FlushTest(FixtureTest): + def test_basic(self): + class Fixture(Base): + pass + + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u1 = User(name='jack') + u2 = User(name='ed') + u2.addresses.append(Address(email_address='foo@bar.com')) + u1.addresses.append(Address(email_address='lala@hoho.com')) + sess.save(u1) + sess.save(u2) + sess.flush() + + sess.clear() + + # test the test fixture a little bit + assert User(name='jack', addresses=[Address(email_address='wrong')]) != sess.query(User).first() + assert User(name='jack', addresses=[Address(email_address='lala@hoho.com')]) == sess.query(User).first() + + assert [ + User(name='jack', addresses=[Address(email_address='lala@hoho.com')]), + User(name='ed', addresses=[Address(email_address='foo@bar.com')]) + ] == sess.query(User).all() + + @testing.fails_on('maxdb') + def test_delete_nocascade(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), backref='user') + }) + sess = create_session(autoflush=True) + u = User(name='ed') + u.addresses.append(Address(email_address='a')) + u.addresses.append(Address(email_address='b')) + u.addresses.append(Address(email_address='c')) + u.addresses.append(Address(email_address='d')) + u.addresses.append(Address(email_address='e')) + u.addresses.append(Address(email_address='f')) + sess.save(u) + + assert Address(email_address='c') == u.addresses[2] + sess.delete(u.addresses[2]) + sess.delete(u.addresses[4]) + sess.delete(u.addresses[3]) + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) + + sess.clear() + u = sess.query(User).get(u.id) + + sess.delete(u) + + # u.addresses relation will have to force the load + # of all addresses so that they can be updated + sess.flush() + sess.close() + + assert testing.db.scalar(addresses.count(addresses.c.user_id != None)) ==0 + + @testing.fails_on('maxdb') + def test_delete_cascade(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), backref='user', cascade="all, delete-orphan") + }) + sess = create_session(autoflush=True) + u = User(name='ed') + u.addresses.append(Address(email_address='a')) + u.addresses.append(Address(email_address='b')) + u.addresses.append(Address(email_address='c')) + u.addresses.append(Address(email_address='d')) + u.addresses.append(Address(email_address='e')) + u.addresses.append(Address(email_address='f')) + sess.save(u) + + assert Address(email_address='c') == u.addresses[2] + sess.delete(u.addresses[2]) + sess.delete(u.addresses[4]) + sess.delete(u.addresses[3]) + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) + + sess.clear() + u = sess.query(User).get(u.id) + + sess.delete(u) + + # u.addresses relation will have to force the load + # of all addresses so that they can be updated + sess.flush() + sess.close() + + assert testing.db.scalar(addresses.count()) ==0 + + @testing.fails_on('maxdb') + def test_remove_orphans(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), cascade="all, delete-orphan", backref='user') + }) + sess = create_session(autoflush=True) + u = User(name='ed') + u.addresses.append(Address(email_address='a')) + u.addresses.append(Address(email_address='b')) + 
u.addresses.append(Address(email_address='c')) + u.addresses.append(Address(email_address='d')) + u.addresses.append(Address(email_address='e')) + u.addresses.append(Address(email_address='f')) + sess.save(u) + + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='c'), + Address(email_address='d'), Address(email_address='e'), Address(email_address='f')] == sess.query(Address).all() + + assert Address(email_address='c') == u.addresses[2] + + try: + del u.addresses[3] + assert False + except TypeError, e: + assert "doesn't support item deletion" in str(e), str(e) + + for a in u.addresses.filter(Address.email_address.in_(['c', 'e', 'f'])): + u.addresses.remove(a) + + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) + + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == sess.query(Address).all() + + sess.delete(u) + sess.close() + + +def create_backref_test(autoflush, saveuser): + def test_backref(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), backref='user') + }) + sess = create_session(autoflush=autoflush) + + u = User(name='buffy') + + a = Address(email_address='foo@bar.com') + a.user = u + + if saveuser: + sess.save(u) + else: + sess.save(a) + + if not autoflush: + sess.flush() + + assert u in sess + assert a in sess + + self.assert_(list(u.addresses) == [a]) + + a.user = None + if not autoflush: + self.assert_(list(u.addresses) == [a]) + + if not autoflush: + sess.flush() + self.assert_(list(u.addresses) == []) + + test_backref = _function_named( + test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""), + (saveuser and "_saveuser" or "_savead"))) + setattr(FlushTest, test_backref.__name__, test_backref) + +for autoflush in (False, True): + for saveuser in (False, True): + create_backref_test(autoflush, saveuser) + +class DontDereferenceTest(ORMTest): + def define_tables(self, metadata): + global users_table, addresses_table + + users_table = Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(40)), + Column('fullname', String(100)), + Column('password', String(15))) + + addresses_table = Table('addresses', metadata, + Column('id', Integer, primary_key=True), + Column('email_address', String(100), nullable=False), + Column('user_id', Integer, ForeignKey('users.id'))) + def test_no_deref(self): + mapper(User, users_table, properties={ + 'addresses': relation(Address, backref='user', lazy='dynamic') + }) + + mapper(Address, addresses_table) + + session = create_session() + user = User() + user.name = 'joe' + user.fullname = 'Joe User' + user.password = 'Joe\'s secret' + address = Address() + address.email_address = 'joe@joesdomain.example' + address.user = user + session.save(user) + session.flush() + session.clear() + + def query1(): + session = create_session(testing.db) + user = session.query(User).first() + return user.addresses.all() + + def query2(): + session = create_session(testing.db) + return session.query(User).first().addresses.all() + + def query3(): + session = create_session(testing.db) + user = session.query(User).first() + return session.query(User).first().addresses.all() + + self.assertEquals(query1(), [Address(email_address='joe@joesdomain.example')] ) + self.assertEquals(query2(), [Address(email_address='joe@joesdomain.example')] ) + self.assertEquals(query3(), [Address(email_address='joe@joesdomain.example')] ) + + +if __name__ == 
'__main__': + testenv.main() diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index a109be56f0..418df83dda 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -1,17 +1,15 @@ """basic tests of eager loaded attributes""" -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * -from fixtures import * +from testlib.fixtures import * from query import QueryTest -class EagerTest(QueryTest): +class EagerTest(FixtureTest): keep_mappers = False - - def setup_mappers(self): - pass + keep_data = True def test_basic(self): mapper(User, users, properties={ @@ -44,19 +42,39 @@ class EagerTest(QueryTest): assert [ User(id=7, addresses=[ Address(id=1) - ]), + ]), User(id=8, addresses=[ Address(id=3, email_address='ed@bettyboop.com'), Address(id=4, email_address='ed@lala.com'), Address(id=2, email_address='ed@wood.com') - ]), + ]), User(id=9, addresses=[ Address(id=5) - ]), + ]), User(id=10, addresses=[]) ] == q.all() - def test_orderby_secondary(self): + def test_orderby_multi(self): + mapper(User, users, properties = { + 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=[addresses.c.email_address, addresses.c.id]), + }) + q = create_session().query(User) + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=8, addresses=[ + Address(id=3, email_address='ed@bettyboop.com'), + Address(id=4, email_address='ed@lala.com'), + Address(id=2, email_address='ed@wood.com') + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=10, addresses=[]) + ] == q.all() + + def test_orderby_related(self): """tests that a regular mapper select on a single table can order by a relation to a second table""" mapper(Address, addresses) @@ -64,22 +82,22 @@ class EagerTest(QueryTest): mapper(User, users, properties = dict( addresses = relation(Address, lazy=False), )) - + q = create_session().query(User) l = q.filter(User.id==Address.user_id).order_by(Address.email_address).all() - + assert [ User(id=8, addresses=[ Address(id=2, email_address='ed@wood.com'), Address(id=3, email_address='ed@bettyboop.com'), Address(id=4, email_address='ed@lala.com'), - ]), + ]), User(id=9, addresses=[ Address(id=5) - ]), + ]), User(id=7, addresses=[ Address(id=1) - ]), + ]), ] == l def test_orderby_desc(self): @@ -92,18 +110,75 @@ class EagerTest(QueryTest): assert [ User(id=7, addresses=[ Address(id=1) - ]), + ]), User(id=8, addresses=[ Address(id=2, email_address='ed@wood.com'), Address(id=4, email_address='ed@lala.com'), Address(id=3, email_address='ed@bettyboop.com'), - ]), + ]), User(id=9, addresses=[ Address(id=5) - ]), + ]), User(id=10, addresses=[]) ] == sess.query(User).all() + def test_deferred_fk_col(self): + mapper(Address, addresses, properties={ + 'user_id':deferred(addresses.c.user_id), + 'user':relation(User, lazy=False) + }) + mapper(User, users) + + assert [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))] == create_session().query(Address).filter(Address.id.in_([1, 4, 5])).all() + + assert [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))] == create_session().query(Address).filter(Address.id.in_([1, 4, 5])).limit(3).all() + + sess = create_session() + a = sess.query(Address).get(1) + def go(): + assert a.user_id==7 + # assert that the eager loader added 'user_id' to the row + # and deferred loading of that col was disabled + self.assert_sql_count(testing.db, go, 0) + + # 
do the mapping in reverse + # (we would have just used an "addresses" backref but the test fixtures then require the whole + # backref to be set up, lazy loaders trigger, etc.) + clear_mappers() + + mapper(Address, addresses, properties={ + 'user_id':deferred(addresses.c.user_id), + }) + mapper(User, users, properties={'addresses':relation(Address, lazy=False)}) + + assert [User(id=7, addresses=[Address(id=1)])] == create_session().query(User).filter(User.id==7).all() + + assert [User(id=7, addresses=[Address(id=1)])] == create_session().query(User).limit(1).filter(User.id==7).all() + + sess = create_session() + u = sess.query(User).get(7) + def go(): + assert u.addresses[0].user_id==7 + # assert that the eager loader didn't have to affect 'user_id' here + # and that its still deferred + self.assert_sql_count(testing.db, go, 1) + + clear_mappers() + + mapper(User, users, properties={'addresses':relation(Address, lazy=False)}) + mapper(Address, addresses, properties={ + 'user_id':deferred(addresses.c.user_id), + 'dingalings':relation(Dingaling, lazy=False) + }) + mapper(Dingaling, dingalings, properties={ + 'address_id':deferred(dingalings.c.address_id) + }) + sess = create_session() + def go(): + u = sess.query(User).limit(1).get(8) + assert User(id=8, addresses=[Address(id=2, dingalings=[Dingaling(id=1)]), Address(id=3), Address(id=4)]) == u + self.assert_sql_count(testing.db, go, 1) + def test_many_to_many(self): mapper(Keyword, keywords) @@ -114,39 +189,43 @@ class EagerTest(QueryTest): q = create_session().query(Item) def go(): assert fixtures.item_keyword_result == q.all() - self.assert_sql_count(testbase.db, go, 1) - + self.assert_sql_count(testing.db, go, 1) + def go(): assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(keywords.c.name == 'red').all() - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) + + def go(): + assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(keywords.c.name == 'red').all() + self.assert_sql_count(testing.db, go, 1) def test_eager_option(self): mapper(Keyword, keywords) mapper(Item, items, properties = dict( - keywords = relation(Keyword, secondary=item_keywords, lazy=True), + keywords = relation(Keyword, secondary=item_keywords, lazy=True, order_by=keywords.c.id), )) q = create_session().query(Item) def go(): assert fixtures.item_keyword_result[0:2] == q.options(eagerload('keywords')).join('keywords').filter(keywords.c.name == 'red').all() - - self.assert_sql_count(testbase.db, go, 1) + + self.assert_sql_count(testing.db, go, 1) def test_cyclical(self): """test that a circular eager relationship breaks the cycle with a lazy loader""" - + mapper(Address, addresses) mapper(User, users, properties = dict( addresses = relation(Address, lazy=False, backref=backref('user', lazy=False)) )) assert class_mapper(User).get_property('addresses').lazy is False assert class_mapper(Address).get_property('user').lazy is False - + sess = create_session() assert fixtures.user_address_result == sess.query(User).all() - + def test_double(self): """tests eager loading with two relations simulatneously, from the same table, using aliases. """ openorders = alias(orders, 'openorders') @@ -184,7 +263,7 @@ class EagerTest(QueryTest): User(id=10) ] == q.all() - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) def test_double_same_mappers(self): """tests eager loading with two relations simulatneously, from the same table, using aliases. 
""" @@ -224,30 +303,52 @@ class EagerTest(QueryTest): User(id=10) ] == q.all() - self.assert_sql_count(testbase.db, go, 1) - + self.assert_sql_count(testing.db, go, 1) + + def test_no_false_hits(self): + """test that eager loaders don't interpret main table columns as part of their eager load.""" + + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False), + 'orders':relation(Order, lazy=False) + }) + mapper(Address, addresses) + mapper(Order, orders) + + allusers = create_session().query(User).all() + + # using a textual select, the columns will be 'id' and 'name'. + # the eager loaders have aliases which should not hit on those columns, they should + # be required to locate only their aliased/fully table qualified column name. + noeagers = create_session().query(User).from_statement("select * from users").all() + assert 'orders' not in noeagers[0].__dict__ + assert 'addresses' not in noeagers[0].__dict__ + + @testing.fails_on('maxdb') def test_limit(self): """test limit operations combined with lazy-load relationships.""" mapper(Item, items) mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, lazy=False) + 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) }) mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=False), + 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id), 'orders':relation(Order, lazy=True) }) sess = create_session() q = sess.query(User) - if testbase.db.engine.name == 'mssql': + if testing.against('mysql'): l = q.limit(2).all() assert fixtures.user_all_result[:2] == l - else: - l = q.limit(2).offset(1).all() + else: + l = q.limit(2).offset(1).order_by(User.id).all() + print fixtures.user_all_result[1:3] + print l assert fixtures.user_all_result[1:3] == l - + def test_distinct(self): # this is an involved 3x union of the users table to get a lot of rows. # then see if the "distinct" works its way out. 
you actually get the same @@ -265,25 +366,27 @@ class EagerTest(QueryTest): def go(): l = q.filter(s.c.u2_id==User.c.id).distinct().all() assert fixtures.user_address_result == l - self.assert_sql_count(testbase.db, go, 1) - + self.assert_sql_count(testing.db, go, 1) + + @testing.fails_on('maxdb') def test_limit_2(self): mapper(Keyword, keywords) mapper(Item, items, properties = dict( keywords = relation(Keyword, secondary=item_keywords, lazy=False, order_by=[keywords.c.id]), )) - + sess = create_session() q = sess.query(Item) l = q.filter((Item.c.description=='item 2') | (Item.c.description=='item 5') | (Item.c.description=='item 3')).\ order_by(Item.id).limit(2).all() assert fixtures.item_keyword_result[1:3] == l - + + @testing.fails_on('maxdb') def test_limit_3(self): - """test that the ORDER BY is propigated from the inner select to the outer select, when using the + """test that the ORDER BY is propigated from the inner select to the outer select, when using the 'wrapped' select statement resulting from the combination of eager loading and limit/offset clauses.""" - + mapper(Item, items) mapper(Order, orders, properties = dict( items = relation(Item, secondary=order_items, lazy=False) @@ -295,105 +398,123 @@ class EagerTest(QueryTest): orders = relation(Order, lazy=False), )) sess = create_session() - + q = sess.query(User) - if testbase.db.engine.name != 'mssql': - l = q.join('orders').order_by(desc(Order.user_id)).limit(2).offset(1) + if not testing.against('maxdb', 'mssql'): + l = q.join('orders').order_by(Order.user_id.desc()).limit(2).offset(1) assert [ - User(id=9, + User(id=9, orders=[Order(id=2), Order(id=4)], addresses=[Address(id=5)] ), - User(id=7, + User(id=7, orders=[Order(id=1), Order(id=3), Order(id=5)], addresses=[Address(id=1)] ) ] == l.all() - l = q.join('addresses').order_by(desc(Address.email_address)).limit(1).offset(0) + l = q.join('addresses').order_by(Address.email_address.desc()).limit(1).offset(0) assert [ - User(id=7, + User(id=7, orders=[Order(id=1), Order(id=3), Order(id=5)], addresses=[Address(id=1)] ) ] == l.all() + def test_limit_4(self): + # tests the LIMIT/OFFSET aliasing on a mapper against a select. 
original issue from ticket #904 + sel = select([users, addresses.c.email_address], users.c.id==addresses.c.user_id).alias('useralias') + mapper(User, sel, properties={ + 'orders':relation(Order, primaryjoin=sel.c.id==orders.c.user_id, lazy=False) + }) + mapper(Order, orders) + + sess = create_session() + self.assertEquals(sess.query(User).first(), + User(name=u'jack',orders=[ + Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), + Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), + Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5)], + email_address=u'jack@bean.com',id=7) + ) + def test_one_to_many_scalar(self): mapper(User, users, properties = dict( address = relation(mapper(Address, addresses), lazy=False, uselist=False) )) q = create_session().query(User) - + def go(): l = q.filter(users.c.id == 7).all() assert [User(id=7, address=Address(id=1))] == l - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) + @testing.fails_on('maxdb') def test_many_to_one(self): mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy=False) )) sess = create_session() q = sess.query(Address) - + def go(): a = q.filter(addresses.c.id==1).one() assert a.user is not None u1 = sess.query(User).get(7) assert a.user is u1 - self.assert_sql_count(testbase.db, go, 1) - + self.assert_sql_count(testing.db, go, 1) + def test_one_and_many(self): - """tests eager load for a parent object with a child object that + """tests eager load for a parent object with a child object that contains a many-to-many relationship to a third object.""" - + mapper(User, users, properties={ 'orders':relation(Order, lazy=False) }) - mapper(Item, items) + mapper(Item, items) mapper(Order, orders, properties = dict( items = relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) )) - + q = create_session().query(User) - + l = q.filter("users.id in (7, 8, 9)") - + def go(): assert fixtures.user_order_result[0:3] == l.all() - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) def test_double_with_aggregate(self): max_orders_by_user = select([func.max(orders.c.id).label('order_id')], group_by=[orders.c.user_id]).alias('max_orders_by_user') - + max_orders = orders.select(orders.c.id==max_orders_by_user.c.order_id).alias('max_orders') - + mapper(Order, orders) mapper(User, users, properties={ 'orders':relation(Order, backref='user', lazy=False), 'max_order':relation(mapper(Order, max_orders, non_primary=True), lazy=False, uselist=False) }) q = create_session().query(User) - + def go(): assert [ User(id=7, orders=[ Order(id=1), Order(id=3), Order(id=5), - ], + ], max_order=Order(id=5) ), User(id=8, orders=[]), - User(id=9, orders=[Order(id=2),Order(id=4)], + User(id=9, orders=[Order(id=2),Order(id=4)], max_order=Order(id=4) ), User(id=10), ] == q.all() - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) def test_wide(self): mapper(Order, orders, properties={'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id)}) @@ -403,14 +524,14 @@ class EagerTest(QueryTest): orders = relation(Order, lazy = False), )) q = create_session().query(User) - l = q.select() + l = q.all() assert fixtures.user_all_result == q.all() def test_against_select(self): """test eager loading of a mapper which is against a select""" s = select([orders], orders.c.isopen==1).alias('openorders') - + mapper(Order, s, properties={ 'user':relation(User, 
lazy=False) }) @@ -422,15 +543,15 @@ class EagerTest(QueryTest): Order(id=3, user=User(id=7)), Order(id=4, user=User(id=9)) ] == q.all() - - q = q.select_from(s.join(order_items).join(items)).filter(~Item.id.in_(1, 2, 5)) + + q = q.select_from(s.join(order_items).join(items)).filter(~Item.id.in_([1, 2, 5])) assert [ Order(id=3, user=User(id=7)), ] == q.all() def test_aliasing(self): """test that eager loading uses aliases to insulate the eager load from regular criterion against those tables.""" - + mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy=False) )) @@ -438,19 +559,163 @@ class EagerTest(QueryTest): l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id) assert fixtures.user_address_result[1:2] == l.all() +class AddEntityTest(FixtureTest): + keep_mappers = False + keep_data = True + + def _assert_result(self): + return [ + ( + User(id=7, + addresses=[Address(id=1)] + ), + Order(id=1, + items=[Item(id=1), Item(id=2), Item(id=3)] + ), + ), + ( + User(id=7, + addresses=[Address(id=1)] + ), + Order(id=3, + items=[Item(id=3), Item(id=4), Item(id=5)] + ), + ), + ( + User(id=7, + addresses=[Address(id=1)] + ), + Order(id=5, + items=[Item(id=5)] + ), + ), + ( + User(id=9, + addresses=[Address(id=5)] + ), + Order(id=2, + items=[Item(id=1), Item(id=2), Item(id=3)] + ), + ), + ( + User(id=9, + addresses=[Address(id=5)] + ), + Order(id=4, + items=[Item(id=1), Item(id=5)] + ), + ) + ] + + def test_basic(self): + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False), + 'orders':relation(Order) + }) + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) + }) + mapper(Item, items) + + + sess = create_session() + def go(): + ret = sess.query(User).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all() + self.assertEquals(ret, self._assert_result()) + self.assert_sql_count(testing.db, go, 1) + + def test_options(self): + mapper(User, users, properties={ + 'addresses':relation(Address), + 'orders':relation(Order) + }) + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, order_by=items.c.id) + }) + mapper(Item, items) + + sess = create_session() + + def go(): + ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all() + self.assertEquals(ret, self._assert_result()) + self.assert_sql_count(testing.db, go, 6) + + sess.clear() + def go(): + ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).options(eagerload('items', Order)).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all() + self.assertEquals(ret, self._assert_result()) + self.assert_sql_count(testing.db, go, 1) + +class OrderBySecondaryTest(ORMTest): + def define_tables(self, metadata): + global a, b, m2m + m2m = Table('mtom', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + Column('bid', Integer, ForeignKey('b.id')), + ) + + a = Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + ) + b = Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + ) + + def insert_data(self): + a.insert().execute([ + {'id':1, 'data':'a1'}, + {'id':2, 'data':'a2'} + ]) + + b.insert().execute([ + {'id':1, 'data':'b1'}, + {'id':2, 
'data':'b2'}, + {'id':3, 'data':'b3'}, + {'id':4, 'data':'b4'}, + ]) + + m2m.insert().execute([ + {'id':2, 'aid':1, 'bid':1}, + {'id':4, 'aid':2, 'bid':4}, + {'id':1, 'aid':1, 'bid':3}, + {'id':6, 'aid':2, 'bid':2}, + {'id':3, 'aid':1, 'bid':2}, + {'id':5, 'aid':2, 'bid':3}, + ]) + + def test_ordering(self): + class A(Base):pass + class B(Base):pass + + mapper(A, a, properties={ + 'bs':relation(B, secondary=m2m, lazy=False, order_by=m2m.c.id) + }) + mapper(B, b) + + sess = create_session() + self.assertEquals(sess.query(A).all(), [A(data='a1', bs=[B(data='b3'), B(data='b1'), B(data='b2')]), A(bs=[B(data='b4'), B(data='b3'), B(data='b2')])]) + + class SelfReferentialEagerTest(ORMTest): def define_tables(self, metadata): global nodes nodes = Table('nodes', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('node_id_seq', optional=True), primary_key=True), Column('parent_id', Integer, ForeignKey('nodes.id')), Column('data', String(30))) - + + @testing.fails_on('maxdb') def test_basic(self): class Node(Base): def append(self, node): self.children.append(node) - + mapper(Node, nodes, properties={ 'children':relation(Node, lazy=False, join_depth=3) }) @@ -476,8 +741,132 @@ class SelfReferentialEagerTest(ORMTest): ]), Node(data='n13') ]) == d - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) + + + def test_lazy_fallback_doesnt_affect_eager(self): + class Node(Base): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=False, join_depth=1) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.save(n1) + sess.flush() + sess.clear() + + # eager load with join depth 1. when eager load of 'n1' + # hits the children of 'n12', no columns are present, eager loader + # degrades to lazy loader; fine. but then, 'n12' is *also* in the + # first level of columns since we're loading the whole table. + # when those rows arrive, now we *can* eager load its children and an + # eager collection should be initialized. essentially the 'n12' instance + # is present in not just two different rows but two distinct sets of columns + # in this result set. 
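# (clarifying sketch, restating the mapping under test rather than extending it:
#
#     mapper(Node, nodes, properties={
#         'children': relation(Node, lazy=False, join_depth=1)
#     })
#
# join_depth bounds how many times the eager loader joins 'nodes' to itself.
# With a depth of 1, rows for 'n12' appear twice in the result of
# query(Node).order_by(Node.data).all(): once as a child of 'n1', where no
# further child columns exist and the loader degrades to lazy, and once as a
# top-level row, where child columns are present and the eager collection is
# initialized -- which is why the single statement counted below suffices.)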
+ def go(): + allnodes = sess.query(Node).order_by(Node.data).all() + n12 = allnodes[2] + assert n12.data == 'n12' + print "N12 IS", id(n12) + print [c.data for c in n12.children] + assert [ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ] == list(n12.children) + self.assert_sql_count(testing.db, go, 1) + + def test_with_deferred(self): + class Node(Base): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=False, join_depth=3), + 'data':deferred(nodes.c.data) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + sess.save(n1) + sess.flush() + sess.clear() + + def go(): + assert Node(data='n1', children=[Node(data='n11'), Node(data='n12')]) == sess.query(Node).first() + self.assert_sql_count(testing.db, go, 4) + + sess.clear() + + def go(): + assert Node(data='n1', children=[Node(data='n11'), Node(data='n12')]) == sess.query(Node).options(undefer('data')).first() + self.assert_sql_count(testing.db, go, 3) + + sess.clear() + + def go(): + assert Node(data='n1', children=[Node(data='n11'), Node(data='n12')]) == sess.query(Node).options(undefer('data'), undefer('children.data')).first() + self.assert_sql_count(testing.db, go, 1) + + + + def test_options(self): + class Node(Base): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=True) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.save(n1) + sess.flush() + sess.clear() + def go(): + d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() + assert Node(data='n1', children=[ + Node(data='n11'), + Node(data='n12', children=[ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ]), + Node(data='n13') + ]) == d + self.assert_sql_count(testing.db, go, 2) + + def go(): + d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() + + # test that the query isn't wrapping the initial query for eager loading. 
+ # testing only sqlite for now since the query text is slightly different on other + # dialects + if testing.against('sqlite'): + self.assert_sql(testing.db, go, [ + ( + "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes WHERE nodes.data = :data_1 ORDER BY nodes.oid LIMIT 1 OFFSET 0", + {'data_1': 'n1'} + ), + ]) + @testing.fails_on('maxdb') def test_no_depth(self): class Node(Base): def append(self, node): @@ -508,7 +897,169 @@ class SelfReferentialEagerTest(ORMTest): ]), Node(data='n13') ]) == d - self.assert_sql_count(testbase.db, go, 3) + self.assert_sql_count(testing.db, go, 3) + +class SelfReferentialM2MEagerTest(ORMTest): + def define_tables(self, metadata): + global widget, widget_rel + + widget = Table('widget', metadata, + Column('id', Integer, primary_key=True), + Column('name', Unicode(40), nullable=False, unique=True), + ) + + widget_rel = Table('widget_rel', metadata, + Column('parent_id', Integer, ForeignKey('widget.id')), + Column('child_id', Integer, ForeignKey('widget.id')), + UniqueConstraint('parent_id', 'child_id'), + ) + def test_basic(self): + class Widget(Base): + pass + + mapper(Widget, widget, properties={ + 'children': relation(Widget, secondary=widget_rel, + primaryjoin=widget_rel.c.parent_id==widget.c.id, + secondaryjoin=widget_rel.c.child_id==widget.c.id, + lazy=False, join_depth=1, + ) + }) + sess = create_session() + w1 = Widget(name=u'w1') + w2 = Widget(name=u'w2') + w1.children.append(w2) + sess.save(w1) + sess.flush() + sess.clear() + +# l = sess.query(Widget).filter(Widget.name=='w1').all() +# print l + assert [Widget(name='w1', children=[Widget(name='w2')])] == sess.query(Widget).filter(Widget.name==u'w1').all() + +class CyclicalInheritingEagerTest(ORMTest): + def define_tables(self, metadata): + global t1, t2 + t1 = Table('t1', metadata, + Column('c1', Integer, primary_key=True), + Column('c2', String(30)), + Column('type', String(30)) + ) + + t2 = Table('t2', metadata, + Column('c1', Integer, primary_key=True), + Column('c2', String(30)), + Column('type', String(30)), + Column('t1.id', Integer, ForeignKey('t1.c1'))) + + def test_basic(self): + class T(object): + pass + + class SubT(T): + pass + + class T2(object): + pass + + class SubT2(T2): + pass + + mapper(T, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1') + mapper(SubT, None, inherits=T, polymorphic_identity='subt1', properties={ + 't2s':relation(SubT2, lazy=False, backref=backref('subt', lazy=False)) + }) + mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') + mapper(SubT2, None, inherits=T2, polymorphic_identity='subt2') + + # testing a particular endless loop condition in eager join setup + create_session().query(SubT).all() + +class SubqueryTest(ORMTest): + def define_tables(self, metadata): + global users_table, tags_table + + users_table = Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(16)) + ) + + tags_table = Table('tags', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey("users.id")), + Column('score1', Float), + Column('score2', Float), + ) + + def test_label_anonymizing(self): + """test that eager loading works with subqueries with labels, + even if an explicit labelname which conflicts with a label on the parent. 
+ + There's not much reason a column_property() would ever need to have a label + of a specific name (and they don't even need labels these days), + unless you'd like the name to line up with a name + that you may be using for a straight textual statement used for loading + instances of that type. + + """ + class User(Base): + @property + def prop_score(self): + return sum([tag.prop_score for tag in self.tags]) + + class Tag(Base): + @property + def prop_score(self): + return self.score1 * self.score2 + + for labeled, labelname in [(True, 'score'), (True, None), (False, None)]: + clear_mappers() + + tag_score = (tags_table.c.score1 * tags_table.c.score2) + user_score = select([func.sum(tags_table.c.score1 * + tags_table.c.score2)], + tags_table.c.user_id == users_table.c.id) + + if labeled: + tag_score = tag_score.label(labelname) + user_score = user_score.label(labelname) + else: + user_score = user_score.as_scalar() + + mapper(Tag, tags_table, properties={ + 'query_score': column_property(tag_score), + }) + + + mapper(User, users_table, properties={ + 'tags': relation(Tag, backref='user', lazy=False), + 'query_score': column_property(user_score), + }) + + session = create_session() + session.save(User(name='joe', tags=[Tag(score1=5.0, score2=3.0), Tag(score1=55.0, score2=1.0)])) + session.save(User(name='bar', tags=[Tag(score1=5.0, score2=4.0), Tag(score1=50.0, score2=1.0), Tag(score1=15.0, score2=2.0)])) + session.flush() + session.clear() + + def go(): + for user in session.query(User).all(): + self.assertEquals(user.query_score, user.prop_score) + self.assert_sql_count(testing.db, go, 1) + + + # fails for non labeled (fixed in 0.5): + if labeled: + def go(): + u = session.query(User).filter_by(name='joe').one() + self.assertEquals(u.query_score, u.prop_score) + self.assert_sql_count(testing.db, go, 1) + else: + u = session.query(User).filter_by(name='joe').one() + self.assertEquals(u.query_score, u.prop_score) + + for t in (tags_table, users_table): + t.delete().execute() + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/orm/entity.py b/test/orm/entity.py index da76e8df05..760f8fce90 100644 --- a/test/orm/entity.py +++ b/test/orm/entity.py @@ -1,92 +1,109 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext from testlib import * from testlib.tables import * -class EntityTest(AssertMixin): +class EntityTest(TestBase, AssertsExecutionResults): """tests mappers that are constructed based on "entity names", which allows the same class to have multiple primary mappers """ + + @testing.uses_deprecated('SessionContext') def setUpAll(self): global user1, user2, address1, address2, metadata, ctx - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) ctx = SessionContext(create_session) - - user1 = Table('user1', metadata, - Column('user_id', Integer, Sequence('user1_id_seq'), primary_key=True), + + user1 = Table('user1', metadata, + Column('user_id', Integer, Sequence('user1_id_seq', optional=True), + primary_key=True), Column('name', String(60), nullable=False) ) - user2 = Table('user2', metadata, - Column('user_id', Integer, Sequence('user2_id_seq'), primary_key=True), + user2 = Table('user2', metadata, + Column('user_id', Integer, Sequence('user2_id_seq', optional=True), + primary_key=True), Column('name', String(60), nullable=False) ) address1 = Table('address1', metadata, - Column('address_id', Integer, 
Sequence('address1_id_seq'), primary_key=True), - Column('user_id', Integer, ForeignKey(user1.c.user_id), nullable=False), + Column('address_id', Integer, + Sequence('address1_id_seq', optional=True), + primary_key=True), + Column('user_id', Integer, ForeignKey(user1.c.user_id), + nullable=False), Column('email', String(100), nullable=False) ) address2 = Table('address2', metadata, - Column('address_id', Integer, Sequence('address2_id_seq'), primary_key=True), - Column('user_id', Integer, ForeignKey(user2.c.user_id), nullable=False), + Column('address_id', Integer, + Sequence('address2_id_seq', optional=True), + primary_key=True), + Column('user_id', Integer, ForeignKey(user2.c.user_id), + nullable=False), Column('email', String(100), nullable=False) ) metadata.create_all() def tearDownAll(self): metadata.drop_all() def tearDown(self): + ctx.current.clear() clear_mappers() for t in metadata.table_iterator(reverse=True): t.delete().execute() + @testing.uses_deprecated('SessionContextExt') def testbasic(self): """tests a pair of one-to-many mapper structures, establishing that both parent and child objects honor the "entity_name" attribute attached to the object instances.""" class User(object):pass class Address(object):pass - + a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.mapper_extension) - a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension) + a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension) u1mapper = mapper(User, user1, entity_name='user1', properties ={ 'addresses':relation(a1mapper) }, extension=ctx.mapper_extension) u2mapper =mapper(User, user2, entity_name='user2', properties={ 'addresses':relation(a2mapper) }, extension=ctx.mapper_extension) - + u1 = User(_sa_entity_name='user1') u1.name = 'this is user 1' a1 = Address(_sa_entity_name='address1') a1.email='a1@foo.com' u1.addresses.append(a1) - + u2 = User(_sa_entity_name='user2') u2.name='this is user 2' a2 = Address(_sa_entity_name='address2') a2.email='a2@foo.com' u2.addresses.append(a2) - + ctx.current.flush() assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] - assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] - assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')] ctx.current.clear() - u1list = ctx.current.query(User, entity_name='user1').select() - u2list = ctx.current.query(User, entity_name='user2').select() + u1list = ctx.current.query(User, entity_name='user1').all() + u2list = ctx.current.query(User, entity_name='user2').all() assert len(u1list) == len(u2list) == 1 assert u1list[0] is not u2list[0] assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + u1 = ctx.current.query(User, entity_name='user1').first() + ctx.current.refresh(u1) + ctx.current.expire(u1) + + def testcascade(self): """same as testbasic but relies on session cascading""" class User(object):pass class Address(object):pass a1mapper = mapper(Address, address1, entity_name='address1') - a2mapper = mapper(Address, address2, entity_name='address2') + a2mapper = mapper(Address, address2, entity_name='address2') u1mapper = mapper(User, user1, entity_name='user1', properties ={ 
'addresses':relation(a1mapper) }) @@ -109,16 +126,16 @@ class EntityTest(AssertMixin): u2.addresses.append(a2) sess.save(u2, entity_name='user2') print u2.__dict__ - + sess.flush() assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] - assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] - assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')] sess.clear() - u1list = sess.query(User, entity_name='user1').select() - u2list = sess.query(User, entity_name='user2').select() + u1list = sess.query(User, entity_name='user1').all() + u2list = sess.query(User, entity_name='user2').all() assert len(u1list) == len(u2list) == 1 assert u1list[0] is not u2list[0] assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 @@ -128,9 +145,9 @@ class EntityTest(AssertMixin): class User(object):pass class Address1(object):pass class Address2(object):pass - + a1mapper = mapper(Address1, address1, extension=ctx.mapper_extension) - a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension) + a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension) u1mapper = mapper(User, user1, entity_name='user1', properties ={ 'addresses':relation(a1mapper) }, extension=ctx.mapper_extension) @@ -153,12 +170,12 @@ class EntityTest(AssertMixin): ctx.current.flush() assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] - assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] - assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')] ctx.current.clear() - u1list = ctx.current.query(User, entity_name='user1').select() - u2list = ctx.current.query(User, entity_name='user2').select() + u1list = ctx.current.query(User, entity_name='user1').all() + u2list = ctx.current.query(User, entity_name='user2').all() assert len(u1list) == len(u2list) == 1 assert u1list[0] is not u2list[0] assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 @@ -166,7 +183,7 @@ class EntityTest(AssertMixin): # is setting up for each load assert isinstance(u1list[0].addresses[0], Address1) assert isinstance(u2list[0].addresses[0], Address2) - + def testpolymorphic_deferred(self): """test that deferred columns load properly using entity names""" class User(object):pass @@ -188,8 +205,8 @@ class EntityTest(AssertMixin): assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] ctx.current.clear() - u1list = ctx.current.query(User, entity_name='user1').select() - u2list = ctx.current.query(User, entity_name='user2').select() + u1list = ctx.current.query(User, entity_name='user1').all() + u2list = ctx.current.query(User, entity_name='user2').all() assert len(u1list) == len(u2list) == 1 assert u1list[0] is not u2list[0] # the deferred column load requires that setup_loader() check that the correct DeferredColumnLoader @@ -197,6 +214,6 @@ class EntityTest(AssertMixin): assert u1list[0].name == 'this is user 1' assert u2list[0].name == 
'this is user 2' - -if __name__ == "__main__": - testbase.main() + +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/expire.py b/test/orm/expire.py new file mode 100644 index 0000000000..58c05a3820 --- /dev/null +++ b/test/orm/expire.py @@ -0,0 +1,772 @@ +"""test attribute/instance expiration, deferral of attributes, etc.""" + +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy import exceptions +from sqlalchemy.orm import * +from testlib import * +from testlib.fixtures import * +import gc + +class ExpireTest(FixtureTest): + keep_mappers = False + refresh_data = True + + def test_expire(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(7) + assert len(u.addresses) == 1 + u.name = 'foo' + del u.addresses[0] + sess.expire(u) + + assert 'name' not in u.__dict__ + + def go(): + assert u.name == 'jack' + self.assert_sql_count(testing.db, go, 1) + assert 'name' in u.__dict__ + + u.name = 'foo' + sess.flush() + # change the value in the DB + users.update(users.c.id==7, values=dict(name='jack')).execute() + sess.expire(u) + # object isnt refreshed yet, using dict to bypass trigger + assert u.__dict__.get('name') != 'jack' + assert 'name' in u._state.expired_attributes + + sess.query(User).all() + # test that it refreshed + assert u.__dict__['name'] == 'jack' + assert 'name' not in u._state.expired_attributes + + def go(): + assert u.name == 'jack' + self.assert_sql_count(testing.db, go, 0) + + def test_persistence_check(self): + mapper(User, users) + s = create_session() + u = s.get(User, 7) + s.clear() + + self.assertRaisesMessage(exceptions.InvalidRequestError, r"is not persistent within this Session", lambda: s.expire(u)) + + def test_expire_doesntload_on_set(self): + mapper(User, users) + + sess = create_session() + u = sess.query(User).get(7) + + sess.expire(u, attribute_names=['name']) + def go(): + u.name = 'somenewname' + self.assert_sql_count(testing.db, go, 0) + sess.flush() + sess.clear() + assert sess.query(User).get(7).name == 'somenewname' + + def test_no_session(self): + mapper(User, users) + sess = create_session() + u = sess.query(User).get(7) + + sess.expire(u, attribute_names=['name']) + sess.expunge(u) + try: + u.name + except exceptions.InvalidRequestError, e: + assert str(e) == "Instance is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed" + + def test_pending_doesnt_raise(self): + mapper(User, users) + sess = create_session() + u = User(id=15) + sess.save(u) + sess.expire(u, ['name']) + assert u.name is None + + def test_no_instance_key(self): + # this tests an artificial condition such that + # an instance is pending, but has expired attributes. 
this + # is actually part of a larger behavior when postfetch needs to + # occur during a flush() on an instance that was just inserted + mapper(User, users) + sess = create_session() + u = sess.query(User).get(7) + + sess.expire(u, attribute_names=['name']) + sess.expunge(u) + del u._instance_key + assert 'name' not in u.__dict__ + sess.save(u) + assert u.name == 'jack' + + def test_expire_preserves_changes(self): + """test that the expire load operation doesn't revert post-expire changes""" + + mapper(Order, orders) + sess = create_session() + o = sess.query(Order).get(3) + sess.expire(o) + + o.description = "order 3 modified" + def go(): + assert o.isopen == 1 + self.assert_sql_count(testing.db, go, 1) + assert o.description == 'order 3 modified' + + del o.description + assert "description" not in o.__dict__ + sess.expire(o, ['isopen']) + sess.query(Order).all() + assert o.isopen == 1 + assert "description" not in o.__dict__ + + assert o.description is None + + o.isopen=15 + sess.expire(o, ['isopen', 'description']) + o.description = 'some new description' + sess.query(Order).all() + assert o.isopen == 1 + assert o.description == 'some new description' + + sess.expire(o, ['isopen', 'description']) + sess.query(Order).all() + del o.isopen + def go(): + assert o.isopen is None + self.assert_sql_count(testing.db, go, 0) + + o.isopen=14 + sess.expire(o) + o.description = 'another new description' + sess.query(Order).all() + assert o.isopen == 1 + assert o.description == 'another new description' + + + def test_expire_committed(self): + """test that the committed state of the attribute receives the most recent DB data""" + mapper(Order, orders) + + sess = create_session() + o = sess.query(Order).get(3) + sess.expire(o) + + orders.update(id=3).execute(description='order 3 modified') + assert o.isopen == 1 + assert o._state.dict['description'] == 'order 3 modified' + def go(): + sess.flush() + self.assert_sql_count(testing.db, go, 0) + + def test_expire_cascade(self): + mapper(User, users, properties={ + 'addresses':relation(Address, cascade="all, refresh-expire") + }) + mapper(Address, addresses) + s = create_session() + u = s.get(User, 8) + assert u.addresses[0].email_address == 'ed@wood.com' + + u.addresses[0].email_address = 'someotheraddress' + s.expire(u) + u.name + print u._state.dict + assert u.addresses[0].email_address == 'ed@wood.com' + + def test_expired_lazy(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(7) + + sess.expire(u) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + def go(): + assert u.addresses[0].email_address == 'jack@bean.com' + assert u.name == 'jack' + # two loads + self.assert_sql_count(testing.db, go, 2) + assert 'name' in u.__dict__ + assert 'addresses' in u.__dict__ + + def test_expired_eager(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(7) + + sess.expire(u) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + def go(): + assert u.addresses[0].email_address == 'jack@bean.com' + assert u.name == 'jack' + # two loads, since relation() + scalar are + # separate right now on per-attribute load + self.assert_sql_count(testing.db, go, 2) + assert 'name' in u.__dict__ + assert 'addresses' in u.__dict__ + + sess.expire(u, ['name', 'addresses']) 
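+        # expire only the scalar 'name' plus the eagerly-loaded 'addresses' collection this time;
+        # the Query-driven load inside go() below should refresh both in a single statement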
+ assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + def go(): + sess.query(User).filter_by(id=7).one() + assert u.addresses[0].email_address == 'jack@bean.com' + assert u.name == 'jack' + # one load, since relation() + scalar are + # together when eager load used with Query + self.assert_sql_count(testing.db, go, 1) + + def test_relation_changes_preserved(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + sess = create_session() + u = sess.query(User).get(8) + sess.expire(u, ['name', 'addresses']) + u.addresses + assert 'name' not in u.__dict__ + del u.addresses[1] + u.name + assert 'name' in u.__dict__ + assert len(u.addresses) == 2 + + def test_eagerload_props_dontload(self): + # relations currently have to load separately from scalar instances. the use case is: + # expire "addresses". then access it. lazy load fires off to load "addresses", but needs + # foreign key or primary key attributes in order to lazy load; hits those attributes, + # such as below it hits "u.id". "u.id" triggers full unexpire operation, eagerloads + # addresses since lazy=False. this is all wihtin lazy load which fires unconditionally; + # so an unnecessary eagerload (or lazyload) was issued. would prefer not to complicate + # lazyloading to "figure out" that the operation should be aborted right now. + + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + sess = create_session() + u = sess.query(User).get(8) + sess.expire(u) + u.id + assert 'addresses' not in u.__dict__ + u.addresses + assert 'addresses' in u.__dict__ + + def test_expire_synonym(self): + mapper(User, users, properties={ + 'uname':synonym('name') + }) + + sess = create_session() + u = sess.query(User).get(7) + assert 'name' in u.__dict__ + assert u.uname == u.name + + sess.expire(u) + assert 'name' not in u.__dict__ + + users.update(users.c.id==7).execute(name='jack2') + assert u.name == 'jack2' + assert u.uname == 'jack2' + assert 'name' in u.__dict__ + + # this wont work unless we add API hooks through the attr. 
system + # to provide "expire" behavior on a synonym + #sess.expire(u, ['uname']) + #users.update(users.c.id==7).execute(name='jack3') + #assert u.uname == 'jack3' + + def test_partial_expire(self): + mapper(Order, orders) + + sess = create_session() + o = sess.query(Order).get(3) + + sess.expire(o, attribute_names=['description']) + assert 'id' in o.__dict__ + assert 'description' not in o.__dict__ + assert o._state.dict['isopen'] == 1 + + orders.update(orders.c.id==3).execute(description='order 3 modified') + + def go(): + assert o.description == 'order 3 modified' + self.assert_sql_count(testing.db, go, 1) + assert o._state.dict['description'] == 'order 3 modified' + + o.isopen = 5 + sess.expire(o, attribute_names=['description']) + assert 'id' in o.__dict__ + assert 'description' not in o.__dict__ + assert o.__dict__['isopen'] == 5 + assert o._state.committed_state['isopen'] == 1 + + def go(): + assert o.description == 'order 3 modified' + self.assert_sql_count(testing.db, go, 1) + assert o.__dict__['isopen'] == 5 + assert o._state.dict['description'] == 'order 3 modified' + assert o._state.committed_state['isopen'] == 1 + + sess.flush() + + sess.expire(o, attribute_names=['id', 'isopen', 'description']) + assert 'id' not in o.__dict__ + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + def go(): + assert o.description == 'order 3 modified' + assert o.id == 3 + assert o.isopen == 5 + self.assert_sql_count(testing.db, go, 1) + + def test_partial_expire_lazy(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(8) + + sess.expire(u, ['name', 'addresses']) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + # hit the lazy loader. 
just does the lazy load, + # doesnt do the overall refresh + def go(): + assert u.addresses[0].email_address=='ed@wood.com' + self.assert_sql_count(testing.db, go, 1) + + assert 'name' not in u.__dict__ + + # check that mods to expired lazy-load attributes + # only do the lazy load + sess.expire(u, ['name', 'addresses']) + def go(): + u.addresses = [Address(id=10, email_address='foo@bar.com')] + self.assert_sql_count(testing.db, go, 1) + + sess.flush() + + # flush has occurred, and addresses was modified, + # so the addresses collection got committed and is + # longer expired + def go(): + assert u.addresses[0].email_address=='foo@bar.com' + assert len(u.addresses) == 1 + self.assert_sql_count(testing.db, go, 0) + + # but the name attribute was never loaded and so + # still loads + def go(): + assert u.name == 'ed' + self.assert_sql_count(testing.db, go, 1) + + def test_partial_expire_eager(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(8) + + sess.expire(u, ['name', 'addresses']) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + def go(): + assert u.addresses[0].email_address=='ed@wood.com' + self.assert_sql_count(testing.db, go, 1) + + # check that mods to expired eager-load attributes + # do the refresh + sess.expire(u, ['name', 'addresses']) + def go(): + u.addresses = [Address(id=10, email_address='foo@bar.com')] + self.assert_sql_count(testing.db, go, 1) + sess.flush() + + # this should ideally trigger the whole load + # but currently it works like the lazy case + def go(): + assert u.addresses[0].email_address=='foo@bar.com' + assert len(u.addresses) == 1 + self.assert_sql_count(testing.db, go, 0) + + def go(): + assert u.name == 'ed' + # scalar attributes have their own load + self.assert_sql_count(testing.db, go, 1) + # ideally, this was already loaded, but we arent + # doing it that way right now + #self.assert_sql_count(testing.db, go, 0) + + def test_relations_load_on_query(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(8) + assert 'name' in u.__dict__ + u.addresses + assert 'addresses' in u.__dict__ + + sess.expire(u, ['name', 'addresses']) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + sess.query(User).options(eagerload('addresses')).filter_by(id=8).all() + assert 'name' in u.__dict__ + assert 'addresses' in u.__dict__ + + def test_partial_expire_deferred(self): + mapper(Order, orders, properties={ + 'description':deferred(orders.c.description) + }) + + sess = create_session() + o = sess.query(Order).get(3) + sess.expire(o, ['description', 'isopen']) + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + + # test that expired attribute access refreshes + # the deferred + def go(): + assert o.isopen == 1 + assert o.description == 'order 3' + self.assert_sql_count(testing.db, go, 1) + + sess.expire(o, ['description', 'isopen']) + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + # test that the deferred attribute triggers the full + # reload + def go(): + assert o.description == 'order 3' + assert o.isopen == 1 + self.assert_sql_count(testing.db, go, 1) + + clear_mappers() + + mapper(Order, orders) + sess.clear() + + # same tests, using deferred at the options level + o = 
sess.query(Order).options(defer('description')).get(3) + + assert 'description' not in o.__dict__ + + # sanity check + def go(): + assert o.description == 'order 3' + self.assert_sql_count(testing.db, go, 1) + + assert 'description' in o.__dict__ + assert 'isopen' in o.__dict__ + sess.expire(o, ['description', 'isopen']) + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + + # test that expired attribute access refreshes + # the deferred + def go(): + assert o.isopen == 1 + assert o.description == 'order 3' + self.assert_sql_count(testing.db, go, 1) + sess.expire(o, ['description', 'isopen']) + + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + # test that the deferred attribute triggers the full + # reload + def go(): + assert o.description == 'order 3' + assert o.isopen == 1 + self.assert_sql_count(testing.db, go, 1) + + def test_eagerload_query_refreshes(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(8) + assert len(u.addresses) == 3 + sess.expire(u) + assert 'addresses' not in u.__dict__ + print "-------------------------------------------" + sess.query(User).filter_by(id=8).all() + assert 'addresses' in u.__dict__ + assert len(u.addresses) == 3 + + def test_expire_all(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + + sess = create_session() + userlist = sess.query(User).all() + assert fixtures.user_address_result == userlist + assert len(list(sess)) == 9 + sess.expire_all() + gc.collect() + assert len(list(sess)) == 4 # since addresses were gc'ed + + userlist = sess.query(User).all() + u = userlist[1] + assert fixtures.user_address_result == userlist + assert len(list(sess)) == 9 + +class PolymorphicExpireTest(ORMTest): + keep_data = True + + def define_tables(self, metadata): + global people, engineers, Person, Engineer + + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + ) + + class Person(Base): + pass + class Engineer(Person): + pass + + def insert_data(self): + people.insert().execute( + {'person_id':1, 'name':'person1', 'type':'person'}, + {'person_id':2, 'name':'engineer1', 'type':'engineer'}, + {'person_id':3, 'name':'engineer2', 'type':'engineer'}, + ) + engineers.insert().execute( + {'person_id':2, 'status':'new engineer'}, + {'person_id':3, 'status':'old engineer'}, + ) + + def test_poly_select(self): + mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') + mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') + + sess = create_session() + [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all() + + sess.expire(p1) + sess.expire(e1, ['status']) + sess.expire(e2) + + for p in [p1, e2]: + assert 'name' not in p.__dict__ + + assert 'name' in e1.__dict__ + assert 'status' not in e2.__dict__ + assert 'status' not in e1.__dict__ + + e1.name = 'new engineer name' + + def go(): + sess.query(Person).all() + self.assert_sql_count(testing.db, go, 3) + + for p in [p1, e1, e2]: + assert 'name' in p.__dict__ + + assert 'status' in e2.__dict__ + 
assert 'status' in e1.__dict__ + def go(): + assert e1.name == 'new engineer name' + assert e2.name == 'engineer2' + assert e1.status == 'new engineer' + self.assert_sql_count(testing.db, go, 0) + self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1'])) + + def test_poly_deferred(self): + mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person', polymorphic_fetch='deferred') + mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') + + sess = create_session() + [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all() + + sess.expire(p1) + sess.expire(e1, ['status']) + sess.expire(e2) + + for p in [p1, e2]: + assert 'name' not in p.__dict__ + + assert 'name' in e1.__dict__ + assert 'status' not in e2.__dict__ + assert 'status' not in e1.__dict__ + + e1.name = 'new engineer name' + + def go(): + sess.query(Person).all() + self.assert_sql_count(testing.db, go, 1) + + for p in [p1, e1, e2]: + assert 'name' in p.__dict__ + + assert 'status' not in e2.__dict__ + assert 'status' not in e1.__dict__ + + def go(): + assert e1.name == 'new engineer name' + assert e2.name == 'engineer2' + assert e1.status == 'new engineer' + assert e2.status == 'old engineer' + self.assert_sql_count(testing.db, go, 2) + self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1'])) + + +class RefreshTest(FixtureTest): + keep_mappers = False + refresh_data = True + + def test_refresh(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), backref='user') + }) + s = create_session() + u = s.get(User, 7) + u.name = 'foo' + a = Address() + assert object_session(a) is None + u.addresses.append(a) + assert a.email_address is None + assert id(a) in [id(x) for x in u.addresses] + + s.refresh(u) + + # its refreshed, so not dirty + assert u not in s.dirty + + # username is back to the DB + assert u.name == 'jack' + + assert id(a) not in [id(x) for x in u.addresses] + + u.name = 'foo' + u.addresses.append(a) + # now its dirty + assert u in s.dirty + assert u.name == 'foo' + assert id(a) in [id(x) for x in u.addresses] + s.expire(u) + + # get the attribute, it refreshes + print "OK------" +# print u.__dict__ +# print u._state.callables + assert u.name == 'jack' + assert id(a) not in [id(x) for x in u.addresses] + + def test_persistence_check(self): + mapper(User, users) + s = create_session() + u = s.get(User, 7) + s.clear() + self.assertRaisesMessage(exceptions.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u)) + + def test_refresh_expired(self): + mapper(User, users) + s = create_session() + u = s.get(User, 7) + s.expire(u) + assert 'name' not in u.__dict__ + s.refresh(u) + assert u.name == 'jack' + + def test_refresh_with_lazy(self): + """test that when a lazy loader is set as a trigger on an object's attribute + (at the attribute level, not the class level), a refresh() operation doesnt + fire the lazy loader or create any problems""" + + s = create_session() + mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) + q = s.query(User).options(lazyload('addresses')) + u = q.filter(users.c.id==8).first() + def go(): + s.refresh(u) + self.assert_sql_count(testing.db, go, 1) + + + def test_refresh_with_eager(self): + """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders""" + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, 
addresses), lazy=False) + }) + + s = create_session() + u = s.get(User, 8) + assert len(u.addresses) == 3 + s.refresh(u) + assert len(u.addresses) == 3 + + s = create_session() + u = s.get(User, 8) + assert len(u.addresses) == 3 + s.expire(u) + assert len(u.addresses) == 3 + + @testing.fails_on('maxdb') + def test_refresh2(self): + """test a hang condition that was occuring on expire/refresh""" + + s = create_session() + mapper(Address, addresses) + + mapper(User, users, properties = dict(addresses=relation(Address,cascade="all, delete-orphan",lazy=False)) ) + + u=User() + u.name='Justin' + a = Address(id=10, email_address='lala') + u.addresses.append(a) + + s.save(u) + s.flush() + s.clear() + u = s.query(User).filter(User.name=='Justin').one() + + s.expire(u) + assert u.name == 'Justin' + + s.refresh(u) + +if __name__ == '__main__': + testenv.main() diff --git a/test/orm/generative.py b/test/orm/generative.py index 4a90c13cb1..aced8f626f 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy import exceptions @@ -12,35 +12,36 @@ class Foo(object): for k in kwargs: setattr(self, k, kwargs[k]) -class GenerativeQueryTest(PersistTest): +class GenerativeQueryTest(TestBase): def setUpAll(self): global foo, metadata - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_id_seq'), primary_key=True), Column('bar', Integer), Column('range', Integer)) - + mapper(Foo, foo) metadata.create_all() - - sess = create_session(bind=testbase.db) + + sess = create_session(bind=testing.db) for i in range(100): sess.save(Foo(bar=i, range=i%10)) sess.flush() - + def tearDownAll(self): metadata.drop_all() clear_mappers() - + def test_selectby(self): - res = create_session(bind=testbase.db).query(Foo).filter_by(range=5) + res = create_session(bind=testing.db).query(Foo).filter_by(range=5) assert res.order_by([Foo.c.bar])[0].bar == 5 assert res.order_by([desc(Foo.c.bar)])[0].bar == 95 - + @testing.unsupported('mssql') + @testing.fails_on('maxdb') def test_slice(self): - sess = create_session(bind=testbase.db) + sess = create_session(bind=testing.db) query = sess.query(Foo) orig = query.all() assert query[1] == orig[1] @@ -52,16 +53,9 @@ class GenerativeQueryTest(PersistTest): assert list(query[-5:]) == orig[-5:] assert query[10:20][5] == orig[10:20][5] - @testing.supported('mssql') - def test_slice_mssql(self): - sess = create_session(bind=testbase.db) - query = sess.query(Foo) - orig = query.all() - assert list(query[:10]) == orig[:10] - assert list(query[:10]) == orig[:10] - + @testing.uses_deprecated('Call to deprecated function apply_max') def test_aggregate(self): - sess = create_session(bind=testbase.db) + sess = create_session(bind=testing.db) query = sess.query(Foo) assert query.count() == 100 assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0 @@ -69,54 +63,57 @@ class GenerativeQueryTest(PersistTest): assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29 assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29 - @testing.unsupported('mysql') def test_aggregate_1(self): - # this one fails in mysql as the result comes back as a string - query = create_session(bind=testbase.db).query(Foo) + if (testing.against('mysql') and + testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')): + return + + query = 
create_session(bind=testing.db).query(Foo) assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435 - @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql') + @testing.fails_on('firebird', 'mssql') def test_aggregate_2(self): - query = create_session(bind=testbase.db).query(Foo) - assert query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5 - - @testing.supported('postgres', 'mysql', 'firebird', 'mssql') - def test_aggregate_2_int(self): - query = create_session(bind=testbase.db).query(Foo) - assert int(query.filter(foo.c.bar<30).avg(foo.c.bar)) == 14 + query = create_session(bind=testing.db).query(Foo) + avg = query.filter(foo.c.bar < 30).avg(foo.c.bar) + assert round(avg, 1) == 14.5 - @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql') + @testing.fails_on('firebird', 'mssql') + @testing.uses_deprecated('Call to deprecated function apply_avg') def test_aggregate_3(self): - query = create_session(bind=testbase.db).query(Foo) - assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5 - assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one() == 14.5 - + query = create_session(bind=testing.db).query(Foo) + + avg_f = query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() + assert round(avg_f, 1) == 14.5 + + avg_o = query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one() + assert round(avg_o, 1) == 14.5 + def test_filter(self): - query = create_session(bind=testbase.db).query(Foo) + query = create_session(bind=testing.db).query(Foo) assert query.count() == 100 assert query.filter(Foo.c.bar < 30).count() == 30 res2 = query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10) assert res2.count() == 19 - + def test_options(self): - query = create_session(bind=testbase.db).query(Foo) + query = create_session(bind=testing.db).query(Foo) class ext1(MapperExtension): def populate_instance(self, mapper, selectcontext, row, instance, **flags): instance.TEST = "hello world" - return EXT_PASS + return EXT_CONTINUE assert query.options(extension(ext1()))[0].TEST == "hello world" - + def test_order_by(self): - query = create_session(bind=testbase.db).query(Foo) + query = create_session(bind=testing.db).query(Foo) assert query.order_by([Foo.c.bar])[0].bar == 0 assert query.order_by([desc(Foo.c.bar)])[0].bar == 99 def test_offset(self): - query = create_session(bind=testbase.db).query(Foo) + query = create_session(bind=testing.db).query(Foo) assert list(query.order_by([Foo.c.bar]).offset(10))[0].bar == 10 - + def test_offset(self): - query = create_session(bind=testbase.db).query(Foo) + query = create_session(bind=testing.db).query(Foo) assert len(list(query.limit(10))) == 10 class Obj1(object): @@ -124,7 +121,7 @@ class Obj1(object): class Obj2(object): pass -class GenerativeTest2(PersistTest): +class GenerativeTest2(TestBase): def setUpAll(self): global metadata, table1, table2 metadata = MetaData() @@ -137,24 +134,24 @@ class GenerativeTest2(PersistTest): ) mapper(Obj1, table1) mapper(Obj2, table2) - metadata.create_all(bind=testbase.db) - testbase.db.execute(table1.insert(), {'id':1},{'id':2},{'id':3},{'id':4}) - testbase.db.execute(table2.insert(), {'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\ + metadata.create_all(bind=testing.db) + testing.db.execute(table1.insert(), {'id':1},{'id':2},{'id':3},{'id':4}) + testing.db.execute(table2.insert(), {'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\ {'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3}) def tearDownAll(self): - metadata.drop_all(bind=testbase.db) + metadata.drop_all(bind=testing.db) 
clear_mappers() def test_distinctcount(self): - query = create_session(bind=testbase.db).query(Obj1) + query = create_session(bind=testing.db).query(Obj1) assert query.count() == 4 res = query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)) assert res.count() == 3 res = query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)).distinct() self.assertEqual(res.count(), 1) -class RelationsTest(AssertMixin): +class RelationsTest(TestBase, AssertsExecutionResults): def setUpAll(self): tables.create() tables.data() @@ -169,7 +166,7 @@ class RelationsTest(AssertMixin): 'items':relation(mapper(tables.Item, tables.orderitems)) })) }) - session = create_session(bind=testbase.db) + session = create_session(bind=testing.db) query = session.query(tables.User) x = query.join(['orders', 'items']).filter(tables.Item.c.item_id==2) print x.compile() @@ -181,7 +178,7 @@ class RelationsTest(AssertMixin): 'items':relation(mapper(tables.Item, tables.orderitems)) })) }) - session = create_session(bind=testbase.db) + session = create_session(bind=testing.db) query = session.query(tables.User) x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) print x.compile() @@ -193,7 +190,7 @@ class RelationsTest(AssertMixin): 'items':relation(mapper(tables.Item, tables.orderitems)) })) }) - session = create_session(bind=testbase.db) + session = create_session(bind=testing.db) query = session.query(tables.User) x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count() assert x==2 @@ -203,18 +200,18 @@ class RelationsTest(AssertMixin): 'items':relation(mapper(tables.Item, tables.orderitems)) })) }) - session = create_session(bind=testbase.db) + session = create_session(bind=testing.db) query = session.query(tables.User) - x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\ + x = query.select_from(tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)).\ filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) print x.compile() self.assert_result(list(x), tables.User, *tables.user_result[1:3]) - -class CaseSensitiveTest(PersistTest): + +class CaseSensitiveTest(TestBase): def setUpAll(self): global metadata, table1, table2 - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) table1 = Table('Table1', metadata, Column('ID', Integer, primary_key=True), ) @@ -232,9 +229,9 @@ class CaseSensitiveTest(PersistTest): def tearDownAll(self): metadata.drop_all() clear_mappers() - + def test_distinctcount(self): - q = create_session(bind=testbase.db).query(Obj1) + q = create_session(bind=testing.db).query(Obj1) assert q.count() == 4 res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)) assert res.count() == 3 @@ -244,27 +241,25 @@ class CaseSensitiveTest(PersistTest): class SelfRefTest(ORMTest): def define_tables(self, metadata): global t1 - t1 = Table('t1', metadata, + t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('t1.id')) ) def test_noautojoin(self): class T(object):pass mapper(T, t1, properties={'children':relation(T)}) - sess = create_session(bind=testbase.db) - try: - sess.query(T).join('children').select_by(id=7) - assert False - except exceptions.InvalidRequestError, e: - assert str(e) == "Self-referential query on 'T.children (T)' property requires create_aliases=True argument.", str(e) - - try: + sess = create_session(bind=testing.db) + def go(): + 
sess.query(T).join('children') + self.assertRaisesMessage(exceptions.InvalidRequestError, + "Self-referential query on 'T\.children \(T\)' property requires aliased=True argument.", go) + + def go(): sess.query(T).join(['children']).select_by(id=7) - assert False - except exceptions.InvalidRequestError, e: - assert str(e) == "Self-referential query on 'T.children (T)' property requires create_aliases=True argument.", str(e) - - - + self.assertRaisesMessage(exceptions.InvalidRequestError, + "Self-referential query on 'T\.children \(T\)' property requires aliased=True argument.", go) + + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/inheritance/abc_inheritance.py b/test/orm/inheritance/abc_inheritance.py index 3b35b3713d..5f7a107562 100644 --- a/test/orm/inheritance/abc_inheritance.py +++ b/test/orm/inheritance/abc_inheritance.py @@ -1,9 +1,10 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE from testlib import * + def produce_test(parent, child, direction): """produce a testcase for A->B->C inheritance with a self-referential relationship between two of the classes, using either one-to-many or @@ -12,30 +13,30 @@ def produce_test(parent, child, direction): def define_tables(self, meta): global ta, tb, tc ta = ["a", meta] - ta.append(Column('id', Integer, primary_key=True)), + ta.append(Column('id', Integer, primary_key=True)), ta.append(Column('a_data', String(30))) if "a"== parent and direction == MANYTOONE: ta.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) elif "a" == child and direction == ONETOMANY: ta.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) ta = Table(*ta) - + tb = ["b", meta] tb.append(Column('id', Integer, ForeignKey("a.id"), primary_key=True, )) - + tb.append(Column('b_data', String(30))) - + if "b"== parent and direction == MANYTOONE: tb.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) elif "b" == child and direction == ONETOMANY: tb.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) tb = Table(*tb) - + tc = ["c", meta] tc.append(Column('id', Integer, ForeignKey("b.id"), primary_key=True, )) - + tc.append(Column('c_data', String(30))) - + if "c"== parent and direction == MANYTOONE: tc.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) elif "c" == child and direction == ONETOMANY: @@ -50,13 +51,13 @@ def produce_test(parent, child, direction): child_table = {"a":ta, "b":tb, "c": tc}[child] child_table.update(values={child_table.c.parent_id:None}).execute() super(ABCTest, self).tearDown() - + def test_roundtrip(self): parent_table = {"a":ta, "b":tb, "c": tc}[parent] child_table = {"a":ta, "b":tb, "c": tc}[child] remote_side = None - + if direction == MANYTOONE: foreign_keys = [parent_table.c.child_id] elif direction == ONETOMANY: @@ -110,7 +111,9 @@ def produce_test(parent, child, direction): somea = A('somea') someb = B('someb') somec = C('somec') - print "APPENDING", parent.__class__.__name__ , "TO", child.__class__.__name__ + + #print "APPENDING", parent.__class__.__name__ , "TO", child.__class__.__name__ + sess.save(parent_obj) parent_obj.collection.append(child_obj) if direction == ONETOMANY: @@ -137,20 +140,20 @@ def produce_test(parent, child, direction): result2 = 
sess.query(parent_class).get(parent2.id) assert result2.id == parent2.id assert result2.collection[0].id == child_obj.id - + sess.clear() # assert result via polymorphic load of parent object - result = sess.query(A).get_by(id=parent_obj.id) + result = sess.query(A).filter_by(id=parent_obj.id).one() assert result.id == parent_obj.id assert result.collection[0].id == child_obj.id if direction == ONETOMANY: assert result.collection[1].id == child2.id elif direction == MANYTOONE: - result2 = sess.query(A).get_by(id=parent2.id) + result2 = sess.query(A).filter_by(id=parent2.id).one() assert result2.id == parent2.id assert result2.collection[0].id == child_obj.id - + ABCTest.__name__ = "Test%sTo%s%s" % (parent, child, (direction is ONETOMANY and "O2M" or "M2O")) return ABCTest @@ -163,4 +166,4 @@ for parent in ["a", "b", "c"]: if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/inheritance/abc_polymorphic.py b/test/orm/inheritance/abc_polymorphic.py new file mode 100644 index 0000000000..076c7b76b8 --- /dev/null +++ b/test/orm/inheritance/abc_polymorphic.py @@ -0,0 +1,90 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy import exceptions, util +from sqlalchemy.orm import * +from testlib import * +from testlib import fixtures + +class ABCTest(ORMTest): + def define_tables(self, metadata): + global a, b, c + a = Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('adata', String(30)), + Column('type', String(30)), + ) + b = Table('b', metadata, + Column('id', Integer, ForeignKey('a.id'), primary_key=True), + Column('bdata', String(30))) + c = Table('c', metadata, + Column('id', Integer, ForeignKey('b.id'), primary_key=True), + Column('cdata', String(30))) + + def make_test(fetchtype): + def test_roundtrip(self): + class A(fixtures.Base):pass + class B(A):pass + class C(B):pass + + if fetchtype == 'union': + abc = a.outerjoin(b).outerjoin(c) + bc = a.join(b).outerjoin(c) + else: + abc = bc = None + + mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a', polymorphic_fetch=fetchtype) + mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b', polymorphic_fetch=fetchtype) + mapper(C, c, inherits=B, polymorphic_identity='c') + + a1 = A(adata='a1') + b1 = B(bdata='b1', adata='b1') + b2 = B(bdata='b2', adata='b2') + b3 = B(bdata='b3', adata='b3') + c1 = C(cdata='c1', bdata='c1', adata='c1') + c2 = C(cdata='c2', bdata='c2', adata='c2') + c3 = C(cdata='c2', bdata='c2', adata='c2') + + sess = create_session() + for x in (a1, b1, b2, b3, c1, c2, c3): + sess.save(x) + sess.flush() + sess.clear() + + #for obj in sess.query(A).all(): + # print obj + assert [ + A(adata='a1'), + B(bdata='b1', adata='b1'), + B(bdata='b2', adata='b2'), + B(bdata='b3', adata='b3'), + C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(A).all() + + assert [ + B(bdata='b1', adata='b1'), + B(bdata='b2', adata='b2'), + B(bdata='b3', adata='b3'), + C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(B).all() + + assert [ + C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(C).all() + + test_roundtrip = _function_named( + test_roundtrip, 'test_%s' % fetchtype) + return test_roundtrip + + test_union = make_test('union') + test_select = 
make_test('select') + test_deferred = make_test('deferred') + + +if __name__ == '__main__': + testenv.main() diff --git a/test/orm/inheritance/alltests.py b/test/orm/inheritance/alltests.py index 1ab10c0607..e51297f8a6 100644 --- a/test/orm/inheritance/alltests.py +++ b/test/orm/inheritance/alltests.py @@ -1,19 +1,22 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest def suite(): modules_to_test = ( 'orm.inheritance.basic', + 'orm.inheritance.query', 'orm.inheritance.manytomany', 'orm.inheritance.single', 'orm.inheritance.concrete', 'orm.inheritance.polymorph', 'orm.inheritance.polymorph2', 'orm.inheritance.poly_linked_list', + 'orm.inheritance.abc_polymorphic', 'orm.inheritance.abc_inheritance', 'orm.inheritance.productspec', 'orm.inheritance.magazine', - + 'orm.inheritance.selects', + ) alltests = unittest.TestSuite() for name in modules_to_test: @@ -25,4 +28,4 @@ def suite(): if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index be623e1b87..8a0b6f30af 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -1,16 +1,17 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions, util from sqlalchemy.orm import * from testlib import * - +from testlib import fixtures class O2MTest(ORMTest): """deals with inheritance and one-to-many relationships""" def define_tables(self, metadata): global foo, bar, blub - # the 'data' columns are to appease SQLite which cant handle a blank INSERT foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), Column('data', String(20))) bar = Table('bar', metadata, @@ -35,7 +36,7 @@ class O2MTest(ORMTest): return "Bar id %d, data %s" % (self.id, self.data) mapper(Bar, bar, inherits=Foo) - + class Blub(Bar): def __repr__(self): return "Blub id %d, data %s" % (self.id, self.data) @@ -54,19 +55,145 @@ class O2MTest(ORMTest): b1.parent_foo = f b2.parent_foo = f sess.flush() - compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo) + compare = ','.join([repr(b1), repr(b2), repr(b1.parent_foo), repr(b2.parent_foo)]) sess.clear() - l = sess.query(Blub).select() - result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo) + l = sess.query(Blub).all() + result = ','.join([repr(l[0]), repr(l[1]), repr(l[0].parent_foo), repr(l[1].parent_foo)]) + print compare print result self.assert_(compare == result) self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') -class GetTest(ORMTest): +class FalseDiscriminatorTest(ORMTest): + def define_tables(self, metadata): + global t1 + t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('type', Integer, nullable=False)) + + def test_false_discriminator(self): + class Foo(object):pass + class Bar(Foo):pass + mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=1) + mapper(Bar, inherits=Foo, polymorphic_identity=0) + sess = create_session() + f1 = Bar() + sess.save(f1) + sess.flush() + assert f1.type == 0 + sess.clear() + assert isinstance(sess.query(Foo).one(), Bar) + +class PolymorphicSynonymTest(ORMTest): + def define_tables(self, metadata): + global t1, t2 + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('type', String(10), nullable=False), + Column('info', Text)) + 
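+        # t2 carries the subclass-specific 'data' column and joins to t1 on a shared primary key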
t2 = Table('t2', metadata, + Column('id', Integer, ForeignKey('t1.id'), primary_key=True), + Column('data', String(10), nullable=False)) + + def test_polymorphic_synonym(self): + class T1(fixtures.Base): + def info(self): + return "THE INFO IS:" + self._info + def _set_info(self, x): + self._info = x + info = property(info, _set_info) + + class T2(T1):pass + + mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', properties={ + 'info':synonym('_info', map_column=True) + }) + mapper(T2, t2, inherits=T1, polymorphic_identity='t2') + sess = create_session() + at1 = T1(info='at1') + at2 = T2(info='at2', data='t2 data') + sess.save(at1) + sess.save(at2) + sess.flush() + sess.clear() + self.assertEquals(sess.query(T2).filter(T2.info=='at2').one(), at2) + self.assertEquals(at2.info, "THE INFO IS:at2") + + +class CascadeTest(ORMTest): + """that cascades on polymorphic relations continue + cascading along the path of the instance's mapper, not + the base mapper.""" + + def define_tables(self, metadata): + global t1, t2, t3, t4 + t1= Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)) + ) + + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('t1id', Integer, ForeignKey('t1.id')), + Column('type', String(30)), + Column('data', String(30)) + ) + t3 = Table('t3', metadata, + Column('id', Integer, ForeignKey('t2.id'), primary_key=True), + Column('moredata', String(30))) + + t4 = Table('t4', metadata, + Column('id', Integer, primary_key=True), + Column('t3id', Integer, ForeignKey('t3.id')), + Column('data', String(30))) + + def test_cascade(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + class T3(T2): + pass + class T4(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't2s':relation(T2, cascade="all") + }) + mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') + mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={ + 't4s':relation(T4, cascade="all") + }) + mapper(T4, t4) + + sess = create_session() + t1_1 = T1(data='t1') + + t3_1 = T3(data ='t3', moredata='t3') + t2_1 = T2(data='t2') + + t1_1.t2s.append(t2_1) + t1_1.t2s.append(t3_1) + + t4_1 = T4(data='t4') + t3_1.t4s.append(t4_1) + + sess.save(t1_1) + + + assert t4_1 in sess.new + sess.flush() + + sess.delete(t1_1) + assert t4_1 in sess.deleted + sess.flush() + + + +class GetTest(ORMTest): def define_tables(self, metadata): global foo, bar, blub foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), Column('type', String(30)), Column('data', String(20))) @@ -79,15 +206,15 @@ class GetTest(ORMTest): Column('foo_id', Integer, ForeignKey('foo.id')), Column('bar_id', Integer, ForeignKey('bar.id')), Column('data', String(20))) - - def create_test(polymorphic): + + def create_test(polymorphic, name): def test_get(self): class Foo(object): pass class Bar(Foo): pass - + class Blub(Bar): pass @@ -99,7 +226,7 @@ class GetTest(ORMTest): mapper(Foo, foo) mapper(Bar, bar, inherits=Foo) mapper(Blub, blub, inherits=Bar) - + sess = create_session() f = Foo() b = Bar() @@ -108,7 +235,7 @@ class GetTest(ORMTest): sess.save(b) sess.save(bl) sess.flush() - + if polymorphic: def go(): assert sess.query(Foo).get(f.id) == f @@ -117,41 +244,42 @@ class GetTest(ORMTest): assert sess.query(Bar).get(b.id) == b assert sess.query(Bar).get(bl.id) == bl assert sess.query(Blub).get(bl.id) == bl - - 
self.assert_sql_count(testbase.db, go, 0) + + self.assert_sql_count(testing.db, go, 0) else: - # this is testing the 'wrong' behavior of using get() + # this is testing the 'wrong' behavior of using get() # polymorphically with mappers that are not configured to be # polymorphic. the important part being that get() always # returns an instance of the query's type. def go(): assert sess.query(Foo).get(f.id) == f - + bb = sess.query(Foo).get(b.id) assert isinstance(b, Foo) and bb.id==b.id - + bll = sess.query(Foo).get(bl.id) assert isinstance(bll, Foo) and bll.id==bl.id - + assert sess.query(Bar).get(b.id) == b - + bll = sess.query(Bar).get(bl.id) assert isinstance(bll, Bar) and bll.id == bl.id - + assert sess.query(Blub).get(bl.id) == bl - - self.assert_sql_count(testbase.db, go, 3) - + + self.assert_sql_count(testing.db, go, 3) + + test_get = _function_named(test_get, name) return test_get - - test_get_polymorphic = create_test(True) - test_get_nonpolymorphic = create_test(False) + + test_get_polymorphic = create_test(True, 'test_get_polymorphic') + test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic') class ConstructionTest(ORMTest): def define_tables(self, metadata): global content_type, content, product - content_type = Table('content_type', metadata, + content_type = Table('content_type', metadata, Column('id', Integer, primary_key=True) ) content = Table('content', metadata, @@ -159,7 +287,7 @@ class ConstructionTest(ORMTest): Column('content_type_id', Integer, ForeignKey('content_type.id')), Column('type', String(30)) ) - product = Table('product', metadata, + product = Table('product', metadata, Column('id', Integer, ForeignKey('content.id'), primary_key=True) ) @@ -169,14 +297,10 @@ class ConstructionTest(ORMTest): class Product(Content): pass content_types = mapper(ContentType, content_type) - contents = mapper(Content, content, properties={ - 'content_type':relation(content_types) - }, polymorphic_identity='contents') - - products = mapper(Product, product, inherits=contents, polymorphic_identity='products') - try: - compile_mappers() + contents = mapper(Content, content, properties={ + 'content_type':relation(content_types) + }, polymorphic_identity='contents') assert False except exceptions.ArgumentError, e: assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" @@ -195,22 +319,26 @@ class ConstructionTest(ORMTest): p = Product() p.contenttype = ContentType() # TODO: assertion ?? 
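(Illustrative sketch, not part of the patch: the joined-table inheritance pattern that O2MTest, FalseDiscriminatorTest and GetTest above exercise, written against the 0.4-era API used in these files -- mapper(), create_session(), Session.save(); the table and class names below are hypothetical.)

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey, create_engine
    from sqlalchemy.orm import mapper, create_session

    engine = create_engine('sqlite://')
    metadata = MetaData()

    employee = Table('employee', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(50)),
        Column('type', String(20), nullable=False))
    engineer = Table('engineer', metadata,
        Column('id', Integer, ForeignKey('employee.id'), primary_key=True),
        Column('language', String(50)))

    class Employee(object):
        pass
    class Engineer(Employee):
        pass

    # the base mapper names the discriminator column; the subclass mapper
    # supplies inherits= plus its own polymorphic_identity value, which is
    # written to the discriminator on INSERT
    mapper(Employee, employee, polymorphic_on=employee.c.type,
           polymorphic_identity='employee')
    mapper(Engineer, engineer, inherits=Employee, polymorphic_identity='engineer')

    metadata.create_all(engine)
    sess = create_session(bind=engine)
    e = Engineer()
    e.name, e.language = 'dilbert', 'python'
    sess.save(e)
    sess.flush()
    sess.clear()
    # the discriminator makes a query against the base class return Engineer instances
    assert isinstance(sess.query(Employee).all()[0], Engineer)
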
- + class EagerLazyTest(ORMTest): """tests eager load/lazy load of child items off inheritance mappers, tests that LazyLoader constructs the right query condition.""" def define_tables(self, metadata): global foo, bar, bar_foo - foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq'), primary_key=True), - Column('data', String(30))) - bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - Column('data', String(30))) + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), + Column('data', String(30))) + bar = Table('bar', metadata, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(30))) bar_foo = Table('bar_foo', metadata, - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('foo_id', Integer, ForeignKey('foo.id')) + Column('bar_id', Integer, ForeignKey('bar.id')), + Column('foo_id', Integer, ForeignKey('foo.id')) ) + @testing.fails_on('maxdb') def testbasic(self): class Foo(object): pass class Bar(Foo): pass @@ -231,11 +359,11 @@ class EagerLazyTest(ORMTest): bar_foo.insert().execute(bar_id=1, foo_id=3) bar_foo.insert().execute(bar_id=2, foo_id=4) - + sess = create_session() q = sess.query(Bar) - self.assert_(len(q.selectfirst().lazy) == 1) - self.assert_(len(q.selectfirst().eager) == 1) + self.assert_(len(q.first().lazy) == 1) + self.assert_(len(q.first().eager) == 1) class FlushTest(ORMTest): @@ -250,7 +378,7 @@ class FlushTest(ORMTest): roles = Table('role', metadata, Column('id', Integer, primary_key=True), - Column('description', String(32)) + Column('description', String(32)) ) user_roles = Table('user_role', metadata, @@ -262,19 +390,19 @@ class FlushTest(ORMTest): Column('admin_id', Integer, primary_key=True), Column('user_id', Integer, ForeignKey('users.id')) ) - + def testone(self): class User(object):pass class Role(object):pass class Admin(User):pass role_mapper = mapper(Role, roles) user_mapper = mapper(User, users, properties = { - 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False) + 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False) } ) admin_mapper = mapper(Admin, admins, inherits=user_mapper) sess = create_session() - adminrole = Role('admin') + adminrole = Role() sess.save(adminrole) sess.flush() @@ -287,7 +415,7 @@ class FlushTest(ORMTest): a.password = 'admin' sess.save(a) sess.flush() - + assert user_roles.count().scalar() == 1 def testtwo(self): @@ -308,7 +436,7 @@ class FlushTest(ORMTest): } ) - admin_mapper = mapper(Admin, admins, inherits=user_mapper) + admin_mapper = mapper(Admin, admins, inherits=user_mapper) # create roles adminrole = Role('admin') @@ -327,6 +455,109 @@ class FlushTest(ORMTest): sess.flush() assert user_roles.count().scalar() == 1 +class VersioningTest(ORMTest): + def define_tables(self, metadata): + global base, subtable, stuff + base = Table('base', metadata, + Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ), + Column('version_id', Integer, nullable=False), + Column('value', String(40)), + Column('discriminator', Integer, nullable=False) + ) + subtable = Table('subtable', metadata, + Column('id', None, ForeignKey('base.id'), primary_key=True), + Column('subdata', String(50)) + ) + stuff = Table('stuff', metadata, + Column('id', Integer, primary_key=True), + Column('parent', Integer, ForeignKey('base.id')) + ) + + @engines.close_open_connections + def test_save_update(self): + class Base(fixtures.Base): + 
pass + class Sub(Base): + pass + class Stuff(Base): + pass + mapper(Stuff, stuff) + mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={ + 'stuff':relation(Stuff) + }) + mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) + + sess = create_session() + + b1 = Base(value='b1') + s1 = Sub(value='sub1', subdata='some subdata') + sess.save(b1) + sess.save(s1) + + sess.flush() + + sess2 = create_session() + s2 = sess2.query(Base).get(s1.id) + s2.subdata = 'sess2 subdata' + + s1.subdata = 'sess1 subdata' + + sess.flush() + + try: + sess2.query(Base).with_lockmode('read').get(s1.id) + assert False + except exceptions.ConcurrentModificationError, e: + assert True + + try: + sess2.flush() + assert False + except exceptions.ConcurrentModificationError, e: + assert True + + sess2.refresh(s2) + assert s2.subdata == 'sess1 subdata' + s2.subdata = 'sess2 subdata' + sess2.flush() + + def test_delete(self): + class Base(fixtures.Base): + pass + class Sub(Base): + pass + + mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1) + mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) + + sess = create_session() + + b1 = Base(value='b1') + s1 = Sub(value='sub1', subdata='some subdata') + s2 = Sub(value='sub2', subdata='some other subdata') + sess.save(b1) + sess.save(s1) + sess.save(s2) + + sess.flush() + + sess2 = create_session() + s3 = sess2.query(Base).get(s1.id) + sess2.delete(s3) + sess2.flush() + + s2.subdata = 'some new subdata' + sess.flush() + + try: + s1.subdata = 'some new subdata' + sess.flush() + assert False + except exceptions.ConcurrentModificationError, e: + assert True + + + class DistinctPKTest(ORMTest): """test the construction of mapper.primary_key when an inheriting relationship joins on a column other than primary key column.""" @@ -352,9 +583,6 @@ class DistinctPKTest(ORMTest): class Employee(Person): pass - import warnings - warnings.filterwarnings("error", r".*On mapper.*distinct primary key") - def insert_data(self): person_insert = person_table.insert() person_insert.execute(id=1, name='alice') @@ -376,12 +604,12 @@ class DistinctPKTest(ORMTest): def test_explicit_composite_pk(self): person_mapper = mapper(Person, person_table) - mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id]) try: + mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id]) self._do_test(True) assert False - except RuntimeWarning, e: - assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name." + except exceptions.SAWarning, e: + assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. 
Use explicit properties to give each column its own mapped attribute name.", str(e) def test_explicit_pk(self): person_mapper = mapper(Person, person_table) @@ -404,6 +632,73 @@ class DistinctPKTest(ORMTest): assert alice1.name == alice2.name == 'alice' assert bob.name == 'bob' +class SyncCompileTest(ORMTest): + """test that syncrules compile properly on custom inherit conds""" + def define_tables(self, metadata): + global _a_table, _b_table, _c_table + + _a_table = Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('data1', String(128)) + ) + + _b_table = Table('b', metadata, + Column('a_id', Integer, ForeignKey('a.id'), primary_key=True), + Column('data2', String(128)) + ) + + _c_table = Table('c', metadata, + # Column('a_id', Integer, ForeignKey('b.a_id'), primary_key=True), #works + Column('b_a_id', Integer, ForeignKey('b.a_id'), primary_key=True), + Column('data3', String(128)) + ) + + def test_joins(self): + for j1 in (None, _b_table.c.a_id==_a_table.c.id, _a_table.c.id==_b_table.c.a_id): + for j2 in (None, _b_table.c.a_id==_c_table.c.b_a_id, _c_table.c.b_a_id==_b_table.c.a_id): + self._do_test(j1, j2) + for t in _a_table.metadata.table_iterator(reverse=True): + t.delete().execute().close() + + def _do_test(self, j1, j2): + class A(object): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + class B(A): + pass + + class C(B): + pass + + mapper(A, _a_table) + mapper(B, _b_table, inherits=A, + inherit_condition=j1 + ) + mapper(C, _c_table, inherits=B, + inherit_condition=j2 + ) + + session = create_session() + + a = A(data1='a1') + session.save(a) + + b = B(data1='b1', data2='b2') + session.save(b) + + c = C(data1='c1', data2='c2', data3='c3') + session.save(c) + + session.flush() + session.clear() + + assert len(session.query(A).all()) == 3 + assert len(session.query(B).all()) == 2 + assert len(session.query(C).all()) == 1 + + -if __name__ == "__main__": - testbase.main() +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/concrete.py index d95a96da5f..29fa1df605 100644 --- a/test/orm/inheritance/concrete.py +++ b/test/orm/inheritance/concrete.py @@ -1,24 +1,39 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * -class ConcreteTest1(ORMTest): +class ConcreteTest(ORMTest): def define_tables(self, metadata): - global managers_table, engineers_table - managers_table = Table('managers', metadata, + global managers_table, engineers_table, hackers_table, companies + + companies = Table('companies', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50))) + + managers_table = Table('managers', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), Column('manager_data', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')) + ) + + engineers_table = Table('engineers', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('engineer_info', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')) ) - engineers_table = Table('engineers', metadata, + hackers_table = Table('hackers', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), Column('engineer_info', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')), + Column('nickname', String(50)) ) - def testbasic(self): + def 
test_basic(self): class Employee(object): def __init__(self, name): self.name = name @@ -54,16 +69,126 @@ class ConcreteTest1(ORMTest): session.flush() session.clear() - print set([repr(x) for x in session.query(Employee).select()]) - assert set([repr(x) for x in session.query(Employee).select()]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Manager).select()]) == set(["Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Engineer).select()]) == set(["Engineer Kurt knows how to hack"]) + print set([repr(x) for x in session.query(Employee).all()]) + assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Kurt knows how to hack"]) + + def test_multi_level(self): + class Employee(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return self.__class__.__name__ + " " + self.name + + class Manager(Employee): + def __init__(self, name, manager_data): + self.name = name + self.manager_data = manager_data + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.manager_data + + class Engineer(Employee): + def __init__(self, name, engineer_info): + self.name = name + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.engineer_info + + class Hacker(Engineer): + def __init__(self, name, nickname, engineer_info): + self.name = name + self.nickname = nickname + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " '" + \ + self.nickname + "' " + self.engineer_info + + pjoin = polymorphic_union({ + 'manager': managers_table, + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin') + + pjoin2 = polymorphic_union({ + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin2') + + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper(Manager, managers_table, + inherits=employee_mapper, concrete=True, + polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, + with_polymorphic=('*', pjoin2), + polymorphic_on=pjoin2.c.type, + inherits=employee_mapper, concrete=True, + polymorphic_identity='engineer') + hacker_mapper = mapper(Hacker, hackers_table, + inherits=engineer_mapper, + concrete=True, polymorphic_identity='hacker') + + session = create_session() + session.save(Manager('Tom', 'knows how to manage things')) + session.save(Engineer('Jerry', 'knows how to program')) + session.save(Hacker('Kurt', 'Badass', 'knows how to hack')) + session.flush() + session.clear() + + assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + + def test_relation(self): + class Employee(object): + def __init__(self, name): + self.name = 
name + def __repr__(self): + return self.__class__.__name__ + " " + self.name + + class Manager(Employee): + def __init__(self, name, manager_data): + self.name = name + self.manager_data = manager_data + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.manager_data + + class Engineer(Employee): + def __init__(self, name, engineer_info): + self.name = name + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.engineer_info + + class Company(object): + pass + + pjoin = polymorphic_union({ + 'manager':managers_table, + 'engineer':engineers_table + }, 'type', 'pjoin') + + mapper(Company, companies, properties={ + 'employees':relation(Employee, lazy=False) + }) + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer') + + session = create_session() + c = Company() + c.employees.append(Manager('Tom', 'knows how to manage things')) + c.employees.append(Engineer('Kurt', 'knows how to hack')) + session.save(c) + session.flush() + session.clear() + + def go(): + c2 = session.query(Company).get(c.id) + assert set([repr(x) for x in c2.employees]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) + self.assert_sql_count(testing.db, go, 1) - def testwithrelation(self): - pass - - # TODO: test a self-referential relationship on a concrete polymorphic mapping if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/orm/inheritance/magazine.py b/test/orm/inheritance/magazine.py index a0bf241485..621f9639f4 100644 --- a/test/orm/inheritance/magazine.py +++ b/test/orm/inheritance/magazine.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * @@ -51,7 +51,7 @@ class LocationName(BaseObject): class PageSize(BaseObject): def __repr__(self): return "%s(%sx%s, %s)" % (self.__class__.__name__, self.width, self.height, self.name) - + class Magazine(BaseObject): def __repr__(self): return "%s(%s, %s)" % (self.__class__.__name__, repr(self.location), repr(self.size)) @@ -130,12 +130,10 @@ def generate_round_trip_test(use_unions=False, use_joins=False): location_name_mapper = mapper(LocationName, location_name_table) location_mapper = mapper(Location, location_table, properties = { - 'issue': relation(Issue, backref='locations'), + 'issue': relation(Issue, backref=backref('locations', lazy=False, cascade="all, delete-orphan")), '_name': relation(LocationName), }) - issue_mapper.add_property('locations', relation(Location, lazy=False, private=True, backref='issue')) - page_size_mapper = mapper(PageSize, page_size_table) magazine_mapper = mapper(Magazine, magazine_table, properties = { @@ -196,7 +194,7 @@ def generate_round_trip_test(use_unions=False, use_joins=False): page2 = MagazinePage(magazine=magazine,page_no=2) page3 = ClassifiedPage(magazine=magazine,page_no=3) session.save(pub) - + session.flush() print [x for x in session] session.clear() @@ -208,13 +206,14 @@ def generate_round_trip_test(use_unions=False, use_joins=False): print p.issues[0].locations[0].magazine.pages print [page, page2, page3] assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), 
repr(p.issues[0].locations[0].magazine.pages) - - test_roundtrip.__name__ = "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions") + + test_roundtrip = _function_named( + test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions")) setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip) - + for (use_union, use_join) in [(True, False), (False, True), (False, False)]: generate_round_trip_test(use_union, use_join) - + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/manytomany.py index df00f39d0b..f1bbc2ae77 100644 --- a/test/orm/inheritance/manytomany.py +++ b/test/orm/inheritance/manytomany.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * @@ -12,35 +12,28 @@ class InheritTest(ORMTest): global groups global user_group_map - principals = Table( - 'principals', - metadata, - Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True), - Column('name', String(50), nullable=False), - ) + principals = Table('principals', metadata, + Column('principal_id', Integer, + Sequence('principal_id_seq', optional=False), + primary_key=True), + Column('name', String(50), nullable=False)) - users = Table( - 'prin_users', - metadata, - Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), + users = Table('prin_users', metadata, + Column('principal_id', Integer, + ForeignKey('principals.principal_id'), primary_key=True), Column('password', String(50), nullable=False), Column('email', String(50), nullable=False), - Column('login_id', String(50), nullable=False), - - ) + Column('login_id', String(50), nullable=False)) - groups = Table( - 'prin_groups', - metadata, - Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), + groups = Table('prin_groups', metadata, + Column('principal_id', Integer, + ForeignKey('principals.principal_id'), primary_key=True)) - ) - - user_group_map = Table( - 'prin_user_group_map', - metadata, - Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ), - Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ), + user_group_map = Table('prin_user_group_map', metadata, + Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), + primary_key=True ), + Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), + primary_key=True ), ) def testbasic(self): @@ -56,18 +49,12 @@ class InheritTest(ORMTest): pass mapper(Principal, principals) - mapper( - User, - users, - inherits=Principal - ) + mapper(User, users, inherits=Principal) - mapper( - Group, - groups, - inherits=Principal, - properties=dict( users = relation(User, secondary=user_group_map, lazy=True, backref="groups") ) - ) + mapper(Group, groups, inherits=Principal, properties={ + 'users': relation(User, secondary=user_group_map, + lazy=True, backref="groups") + }) g = Group(name="group1") g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1")) @@ -75,13 +62,14 @@ class InheritTest(ORMTest): sess.save(g) sess.flush() # TODO: put an assertion - + class InheritTest2(ORMTest): """deals with inheritance and many-to-many relationships""" def define_tables(self, metadata): global foo, bar, foo_bar foo = Table('foo', metadata, - Column('id', Integer, 
Sequence('foo_id_seq'), primary_key=True), + Column('id', Integer, Sequence('foo_id_seq', optional=True), + primary_key=True), Column('data', String(20)), ) @@ -95,11 +83,11 @@ class InheritTest2(ORMTest): Column('bar_id', Integer, ForeignKey('bar.bid'))) def testget(self): - class Foo(object):pass - def __init__(self, data=None): - self.data = data + class Foo(object): + def __init__(self, data=None): + self.data = data class Bar(Foo):pass - + mapper(Foo, foo) mapper(Bar, bar, inherits=Foo) print foo.join(bar).primary_key @@ -109,13 +97,13 @@ class InheritTest2(ORMTest): sess.save(b) sess.flush() sess.clear() - + # test that "bar.bid" does not need to be referenced in a get # (ticket 185) assert sess.query(Bar).get(b.id).id == b.id - + def testbasic(self): - class Foo(object): + class Foo(object): def __init__(self, data=None): self.data = data @@ -126,7 +114,7 @@ class InheritTest2(ORMTest): mapper(Bar, bar, inherits=Foo, properties={ 'foos': relation(Foo, secondary=foo_bar, lazy=False) }) - + sess = create_session() b = Bar('barfoo') sess.save(b) @@ -140,10 +128,10 @@ class InheritTest2(ORMTest): sess.flush() sess.clear() - l = sess.query(Bar).select() + l = sess.query(Bar).all() print l[0] print l[0].foos - self.assert_result(l, Bar, + self.assert_unordered_result(l, Bar, # {'id':1, 'data':'barfoo', 'bid':1, 'foos':(Foo, [{'id':2,'data':'subfoo1'}, {'id':3,'data':'subfoo2'}])}, {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])}, ) @@ -155,7 +143,8 @@ class InheritTest3(ORMTest): # the 'data' columns are to appease SQLite which cant handle a blank INSERT foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), Column('data', String(20))) bar = Table('bar', metadata, @@ -166,10 +155,10 @@ class InheritTest3(ORMTest): Column('id', Integer, ForeignKey('bar.id'), primary_key=True), Column('data', String(20))) - bar_foo = Table('bar_foo', metadata, + bar_foo = Table('bar_foo', metadata, Column('bar_id', Integer, ForeignKey('bar.id')), Column('foo_id', Integer, ForeignKey('foo.id'))) - + blub_bar = Table('bar_blub', metadata, Column('blub_id', Integer, ForeignKey('blub.id')), Column('bar_id', Integer, ForeignKey('bar.id'))) @@ -177,7 +166,7 @@ class InheritTest3(ORMTest): blub_foo = Table('blub_foo', metadata, Column('blub_id', Integer, ForeignKey('blub.id')), Column('foo_id', Integer, ForeignKey('foo.id'))) - + def testbasic(self): class Foo(object): def __init__(self, data=None): @@ -189,7 +178,7 @@ class InheritTest3(ORMTest): class Bar(Foo): def __repr__(self): return "Bar id %d, data %s" % (self.id, self.data) - + mapper(Bar, bar, inherits=Foo, properties={ 'foos' :relation(Foo, secondary=bar_foo, lazy=True) }) @@ -200,13 +189,15 @@ class InheritTest3(ORMTest): b.foos.append(Foo("foo #1")) b.foos.append(Foo("foo #2")) sess.flush() - compare = repr(b) + repr(b.foos) + compare = repr(b) + repr(sorted([repr(o) for o in b.foos])) sess.clear() - l = sess.query(Bar).select() + l = sess.query(Bar).all() print repr(l[0]) + repr(l[0].foos) - self.assert_(repr(l[0]) + repr(l[0].foos) == compare) - - def testadvanced(self): + found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos])) + self.assertEqual(found, compare) + + @testing.fails_on('maxdb') + def testadvanced(self): class Foo(object): def __init__(self, data=None): self.data = data @@ -222,7 +213,7 @@ class InheritTest3(ORMTest): class Blub(Bar): def __repr__(self): 
return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos])) - + mapper(Blub, blub, inherits=Bar, properties={ 'bars':relation(Bar, secondary=blub_bar, lazy=False), 'foos':relation(Foo, secondary=blub_foo, lazy=False), @@ -242,14 +233,14 @@ class InheritTest3(ORMTest): blubid = bl1.id sess.clear() - l = sess.query(Blub).select() + l = sess.query(Blub).all() print l self.assert_(repr(l[0]) == compare) sess.clear() - x = sess.query(Blub).get_by(id=blubid) + x = sess.query(Blub).filter_by(id=blubid).one() print x self.assert_(repr(x) == compare) - - -if __name__ == "__main__": - testbase.main() + + +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/inheritance/poly_linked_list.py b/test/orm/inheritance/poly_linked_list.py index 7297002f52..b2dd6c658e 100644 --- a/test/orm/inheritance/poly_linked_list.py +++ b/test/orm/inheritance/poly_linked_list.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * @@ -23,45 +23,45 @@ class PolymorphicCircularTest(ORMTest): Column('id', Integer, ForeignKey('table1.id'), primary_key=True), ) - data = Table('data', metadata, + data = Table('data', metadata, Column('id', Integer, primary_key=True), Column('node_id', Integer, ForeignKey('table1.id')), Column('data', String(30)) ) - + #join = polymorphic_union( # { # 'table3' : table1.join(table3), # 'table2' : table1.join(table2), - # 'table1' : table1.select(table1.c.type.in_('table1', 'table1b')), + # 'table1' : table1.select(table1.c.type.in_(['table1', 'table1b'])), # }, None, 'pjoin') - + join = table1.outerjoin(table2).outerjoin(table3).alias('pjoin') #join = None - + class Table1(object): def __init__(self, name, data=None): self.name = name if data is not None: self.data = data def __repr__(self): - return "%s(%d, %s, %s)" % (self.__class__.__name__, self.id, repr(str(self.name)), repr(self.data)) + return "%s(%s, %s, %s)" % (self.__class__.__name__, self.id, repr(str(self.name)), repr(self.data)) class Table1B(Table1): pass - + class Table2(Table1): pass class Table3(Table1): pass - + class Data(object): def __init__(self, data): self.data = data def __repr__(self): - return "%s(%d, %s)" % (self.__class__.__name__, self.id, repr(str(self.data))) - + return "%s(%s, %s)" % (self.__class__.__name__, self.id, repr(str(self.data))) + try: # this is how the mapping used to work. ensure that this raises an error now table1_mapper = mapper(Table1, table1, @@ -69,8 +69,8 @@ class PolymorphicCircularTest(ORMTest): polymorphic_on=table1.c.type, polymorphic_identity='table1', properties={ - 'next': relation(Table1, - backref=backref('prev', primaryjoin=join.c.id==join.c.related_id, foreignkey=join.c.id, uselist=False), + 'next': relation(Table1, + backref=backref('prev', primaryjoin=join.c.id==join.c.related_id, foreignkey=join.c.id, uselist=False), uselist=False, primaryjoin=join.c.id==join.c.related_id), 'data':relation(mapper(Data, data)) } @@ -80,10 +80,10 @@ class PolymorphicCircularTest(ORMTest): except: assert True clear_mappers() - + # currently, the "eager" relationships degrade to lazy relationships # due to the polymorphic load. - # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" + # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" # exception now. 
since eager loading would never work for that relation anyway, its better that the user # gets an exception instead of it silently not eager loading. table1_mapper = mapper(Table1, table1, @@ -91,8 +91,8 @@ class PolymorphicCircularTest(ORMTest): polymorphic_on=table1.c.type, polymorphic_identity='table1', properties={ - 'next': relation(Table1, - backref=backref('prev', primaryjoin=table1.c.id==table1.c.related_id, remote_side=table1.c.id, uselist=False), + 'next': relation(Table1, + backref=backref('prev', primaryjoin=table1.c.id==table1.c.related_id, remote_side=table1.c.id, uselist=False), uselist=False, primaryjoin=table1.c.id==table1.c.related_id), 'data':relation(mapper(Data, data), lazy=False) } @@ -105,27 +105,31 @@ class PolymorphicCircularTest(ORMTest): polymorphic_identity='table2') table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3') - + table1_mapper.compile() assert table1_mapper.primary_key == [table1.c.id], table1_mapper.primary_key - + + @testing.fails_on('maxdb') def testone(self): self.do_testlist([Table1, Table2, Table1, Table2]) + @testing.fails_on('maxdb') def testtwo(self): self.do_testlist([Table3]) - + + @testing.fails_on('maxdb') def testthree(self): self.do_testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1]) + @testing.fails_on('maxdb') def testfour(self): self.do_testlist([ - Table2('t2', [Data('data1'), Data('data2')]), + Table2('t2', [Data('data1'), Data('data2')]), Table1('t1', []), Table3('t3', [Data('data3')]), Table1B('t1b', [Data('data4'), Data('data5')]) ]) - + def do_testlist(self, classes): sess = create_session( ) @@ -147,7 +151,7 @@ class PolymorphicCircularTest(ORMTest): # save to DB sess.save(t) sess.flush() - + # string version of the saved list assertlist = [] node = t @@ -162,7 +166,7 @@ class PolymorphicCircularTest(ORMTest): # clear and query forwards sess.clear() - node = sess.query(Table1).selectfirst(Table1.c.id==t.id) + node = sess.query(Table1).filter(Table1.c.id==t.id).first() assertlist = [] while (node): assertlist.append(node) @@ -174,7 +178,7 @@ class PolymorphicCircularTest(ORMTest): # clear and query backwards sess.clear() - node = sess.query(Table1).selectfirst(Table1.c.id==obj.id) + node = sess.query(Table1).filter(Table1.c.id==obj.id).first() assertlist = [] while (node): assertlist.insert(0, node) @@ -183,7 +187,7 @@ class PolymorphicCircularTest(ORMTest): assert n.next is node node = n backwards = repr(assertlist) - + # everything should match ! 
print "ORIGNAL", original print "BACKWARDS",backwards @@ -191,4 +195,4 @@ class PolymorphicCircularTest(ORMTest): assert original == forwards == backwards if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py index 3eb2e032f0..5442520242 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/polymorph.py @@ -1,120 +1,61 @@ """tests basic polymorphic mapper loading/saving, minimal relations""" -import testbase +import testenv; testenv.configure_for_tests() import sets from sqlalchemy import * from sqlalchemy.orm import * +from sqlalchemy import exceptions from testlib import * +from testlib import fixtures - -class Person(object): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - def get_name(self): - try: - return getattr(self, 'person_name') - except AttributeError: - return getattr(self, 'name') - def __repr__(self): - return "Ordinary person %s" % self.get_name() +class Person(fixtures.Base): + pass class Engineer(Person): - def __repr__(self): - return "Engineer %s, status %s, engineer_name %s, primary_language %s" % (self.get_name(), self.status, self.engineer_name, self.primary_language) + pass class Manager(Person): - def __repr__(self): - return "Manager %s, status %s, manager_name %s" % (self.get_name(), self.status, self.manager_name) + pass class Boss(Manager): - def __repr__(self): - return "Boss %s, status %s, manager_name %s golf swing %s" % (self.get_name(), self.status, self.manager_name, self.golf_swing) - -class Company(object): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - def __repr__(self): - return "Company %s" % self.name + pass +class Company(fixtures.Base): + pass class PolymorphTest(ORMTest): def define_tables(self, metadata): global companies, people, engineers, managers, boss - - # a table to store companies - companies = Table('companies', metadata, - Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key=True), + + companies = Table('companies', metadata, + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) - # we will define an inheritance relationship between the table "people" and "engineers", - # and a second inheritance relationship between the table "people" and "managers" - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + people = Table('people', metadata, + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('company_id', Integer, ForeignKey('companies.company_id')), Column('name', String(50)), Column('type', String(30))) - engineers = Table('engineers', metadata, + engineers = Table('engineers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('status', String(30)), Column('engineer_name', String(50)), Column('primary_language', String(50)), ) - managers = Table('managers', metadata, + managers = Table('managers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('status', String(30)), Column('manager_name', String(50)) ) - boss = Table('boss', metadata, + boss = Table('boss', metadata, Column('boss_id', Integer, ForeignKey('managers.person_id'), primary_key=True), Column('golf_swing', String(30)), ) - - metadata.create_all() - -class 
CompileTest(PolymorphTest): - def testcompile(self): - person_join = polymorphic_union( { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - 'person':people.select(people.c.type=='person'), - }, None, 'pjoin') - person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person') - mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - - session = create_session() - session.save(Manager(name='Tom', status='knows how to manage things')) - session.save(Engineer(name='Kurt', status='knows how to hack')) - session.flush() - print session.query(Engineer).select() - - print session.query(Person).select() - - def testcompile2(self): - """test that a mapper can reference a property whose mapper inherits from this one.""" - person_join = polymorphic_union( { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - 'person':people.select(people.c.type=='person'), - }, None, 'pjoin') - - - person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, - polymorphic_identity='person', - properties = dict(managers = relation(Manager, lazy=True)) - ) - - mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - - #person_mapper.compile() - class_mapper(Manager).compile() + metadata.create_all() class InsertOrderTest(PolymorphTest): def test_insert_order(self): - """test that classes of multiple types mix up mapper inserts + """test that classes of multiple types mix up mapper inserts so that insert order of individual tables is maintained""" person_join = polymorphic_union( { @@ -128,7 +69,9 @@ class InsertOrderTest(PolymorphTest): mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') mapper(Company, companies, properties={ - 'employees': relation(Person, private=True, backref='company', order_by=person_join.c.person_id) + 'employees': relation(Person, + backref='company', + order_by=person_join.c.person_id) }) session = create_session() @@ -141,34 +84,30 @@ class InsertOrderTest(PolymorphTest): session.save(c) session.flush() session.clear() - c = session.query(Company).get(c.company_id) - for e in c.employees: - print e, e._instance_key, e.company - - assert [e.get_name() for e in c.employees] == ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith'] + self.assertEquals(session.query(Company).get(c.company_id), c) class RelationToSubclassTest(PolymorphTest): - def testrelationtosubclass(self): + def test_basic(self): """test a relation to an inheriting mapper where the relation is to a subclass - but the join condition is expressed by the parent table. - + but the join condition is expressed by the parent table. + also test that backrefs work in this case. 
- + this test touches upon a lot of the join/foreign key determination code in properties.py - and creates the need for properties.py to search for conditions individually within + and creates the need for properties.py to search for conditions individually within the mapper's local table as well as the mapper's 'mapped' table, so that relations requiring lots of specificity (like self-referential joins) as well as relations requiring more generalization (like the example here) both come up with proper results.""" - + mapper(Person, people) - + mapper(Engineer, engineers, inherits=Person) mapper(Manager, managers, inherits=Person) mapper(Company, companies, properties={ - 'managers': relation(Manager, lazy=True,backref="company") + 'managers': relation(Manager, backref="company") }) - + sess = create_session() c = Company(name='company1') @@ -177,23 +116,22 @@ class RelationToSubclassTest(PolymorphTest): sess.flush() sess.clear() - sess.query(Company).get_by(company_id=c.company_id) - assert sets.Set([e.get_name() for e in c.managers]) == sets.Set(['pointy haired boss']) + self.assertEquals(sess.query(Company).filter_by(company_id=c.company_id).one(), c) assert c.managers[0].company is c class RoundTripTest(PolymorphTest): pass - + def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, polymorphic_fetch=None, use_outer_joins=False): """generates a round trip test. - + include_base - whether or not to include the base 'person' type in the union. lazy_relation - whether or not the Company relation to People is lazy or eager. redefine_colprop - if we redefine the 'name' column to be 'people_name' on the base Person class use_literal_join - primary join condition is explicitly specified """ def test_roundtrip(self): - # create a union that represents both types of joins. + # create a union that represents both types of joins. 
if not polymorphic_fetch == 'union': person_join = None manager_join = None @@ -208,7 +146,7 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co 'manager':people.join(managers), 'person':people.select(people.c.type=='person'), }, None, 'pjoin') - + manager_join = people.join(managers).outerjoin(boss) else: if use_outer_joins: @@ -226,98 +164,143 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name}) else: person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person') - + mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') mapper(Manager, managers, inherits=person_mapper, select_table=manager_join, polymorphic_identity='manager') mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss') - + if use_literal_join: mapper(Company, companies, properties={ - 'employees': relation(Person, lazy=lazy_relation, primaryjoin=people.c.company_id==companies.c.company_id, private=True, - backref="company" + 'employees': relation(Person, lazy=lazy_relation, + primaryjoin=(people.c.company_id == + companies.c.company_id), + cascade="all,delete-orphan", + backref="company", + order_by=people.c.person_id ) }) else: mapper(Company, companies, properties={ - 'employees': relation(Person, lazy=lazy_relation, private=True, - backref="company" + 'employees': relation(Person, lazy=lazy_relation, + cascade="all, delete-orphan", + backref="company", order_by=people.c.person_id ) }) - + if redefine_colprop: person_attribute_name = 'person_name' else: person_attribute_name = 'name' - + + employees = [ + Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'}), + Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'}), + ] + if include_base: + employees.append(Person(**{person_attribute_name:'joesmith'})) + employees += [ + Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'}), + Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'}) + ] + + pointy = employees[0] + jsmith = employees[-1] + dilbert = employees[1] + session = create_session() c = Company(name='company1') - c.employees.append(Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'})) - c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'})) - if include_base: - c.employees.append(Person(status='HHH', **{person_attribute_name:'joesmith'})) - c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'})) - c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'})) + c.employees = employees session.save(c) - print session.new + session.flush() session.clear() - id = c.company_id - c = session.query(Company).get(id) - for e in c.employees: - print e, e._instance_key, e.company - if include_base: - assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), 
('wally', 'CGG'), ('jsmith', 'ABA')]) - else: - assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')]) - print "\n" + + self.assertEquals(session.query(Person).get(dilbert.person_id), dilbert) + session.clear() + + self.assertEquals(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert) + session.clear() + def go(): + cc = session.query(Company).get(c.company_id) + for e in cc.employees: + assert e._instance_key[0] == Person + self.assertEquals(cc.employees, employees) + + if not lazy_relation: + if polymorphic_fetch=='union': + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 5) + + else: + if polymorphic_fetch=='union': + self.assert_sql_count(testing.db, go, 2) + else: + self.assert_sql_count(testing.db, go, 6) + # test selecting from the query, using the base mapped table (people) as the selection criterion. # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" - dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() - dilbert2 = session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first() - assert dilbert is dilbert2 - + self.assertEquals( + session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first(), + dilbert + ) + self.assertEquals( + session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first(), + dilbert + ) + # test selecting from the query, joining against an alias of the base "people" table. test that # the "palias" alias does *not* get sucked up into the "person_join" conversion. palias = people.alias("palias") - session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() - dilbert2 = session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() - assert dilbert is dilbert2 - - session.query(Person).filter((Engineer.engineer_name=="engineer1") & (Engineer.person_id==people.c.person_id)).first() + dilbert = session.query(Person).get(dilbert.person_id) + assert dilbert is session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() + assert dilbert is session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() + assert dilbert is session.query(Person).filter((Engineer.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)).first() + assert dilbert is session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0] - dilbert2 = session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0] - assert dilbert is dilbert2 - dilbert.engineer_name = 'hes dibert!' 
session.flush() session.clear() - - # save/load some managers/bosses - b = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'}) - session.save(b) + + if polymorphic_fetch == 'select': + def go(): + session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + self.assert_sql_count(testing.db, go, 2) + session.clear() + dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + def go(): + # assert that only primary table is queried for already-present-in-session + d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + self.assert_sql_count(testing.db, go, 1) + + # test standalone orphans + daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'}) + session.save(daboss) + self.assertRaises(exceptions.FlushError, session.flush) + c = session.query(Company).first() + daboss.company = c + manager_list = [e for e in c.employees if isinstance(e, Manager)] session.flush() session.clear() - c = session.query(Manager).all() - assert sets.Set([repr(x) for x in c]) == sets.Set(["Manager pointy haired boss, status AAB, manager_name manager1", "Manager jsmith, status ABA, manager_name manager2", "Boss daboss, status BBB, manager_name boss golf swing fore"]), repr([repr(x) for x in c]) - - c = session.query(Company).get(id) - for e in c.employees: - print e, e._instance_key + self.assertEquals(session.query(Manager).order_by(Manager.person_id).all(), manager_list) + c = session.query(Company).first() + session.delete(c) session.flush() + self.assertEquals(people.count().scalar(), 0) - test_roundtrip.__name__ = "test_%s%s%s%s%s" % ( - (lazy_relation and "lazy" or "eager"), - (include_base and "_inclbase" or ""), - (redefine_colprop and "_redefcol" or ""), - (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")), - (use_outer_joins and '_outerjoins' or '') - ) + test_roundtrip = _function_named( + test_roundtrip, "test_%s%s%s%s%s" % ( + (lazy_relation and "lazy" or "eager"), + (include_base and "_inclbase" or ""), + (redefine_colprop and "_redefcol" or ""), + (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")), + (use_outer_joins and '_outerjoins' or ''))) setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip) for include_base in [True, False]: @@ -330,7 +313,6 @@ for include_base in [True, False]: generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, use_outer_joins) else: generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, False) - -if __name__ == "__main__": - testbase.main() +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/polymorph2.py index a2f9c4a5f0..ed003927bb 100644 --- a/test/orm/inheritance/polymorph2.py +++ b/test/orm/inheritance/polymorph2.py @@ -1,8 +1,13 @@ -import testbase +"""this is a test suite consisting mainly of end-user test cases, testing all kinds of painful +inheritance setups for which we maintain compatibility. 
+""" + +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions, util from sqlalchemy.orm import * from testlib import * - +from testlib import fixtures class AttrSettable(object): def __init__(self, **kwargs): @@ -16,13 +21,13 @@ class RelationTest1(ORMTest): def define_tables(self, metadata): global people, managers - people = Table('people', metadata, + people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), Column('manager_id', Integer, ForeignKey('managers.person_id', use_alter=True, name="mpid_fq")), Column('name', String(50)), Column('type', String(30))) - managers = Table('managers', metadata, + managers = Table('managers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('status', String(30)), Column('manager_name', String(50)) @@ -31,29 +36,25 @@ class RelationTest1(ORMTest): def tearDown(self): people.update(values={people.c.manager_id:None}).execute() super(RelationTest1, self).tearDown() - + def testparentrefsdescendant(self): class Person(AttrSettable): pass class Manager(Person): pass - - mapper(Person, people, properties={ - 'manager':relation(Manager, primaryjoin=people.c.manager_id==managers.c.person_id, uselist=False) - }) - mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id) - try: - compile_mappers() - except exceptions.ArgumentError, ar: - assert str(ar) == "Can't determine relation direction for relationship 'Person.manager (Manager)' - foreign key columns are present in both the parent and the child's mapped tables. Specify 'foreign_keys' argument.", str(ar) - - clear_mappers() - + + # note that up until recently (0.4.4), we had to specify "foreign_keys" here + # for this primary join. 
mapper(Person, people, properties={ - 'manager':relation(Manager, primaryjoin=people.c.manager_id==managers.c.person_id, foreignkey=people.c.manager_id, uselist=False, post_update=True) + 'manager':relation(Manager, primaryjoin=(people.c.manager_id == + managers.c.person_id), + uselist=False, post_update=True) }) - mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id) - + mapper(Manager, managers, inherits=Person, + inherit_condition=people.c.person_id==managers.c.person_id) + + self.assertEquals(class_mapper(Person).get_property('manager').foreign_keys, set([people.c.manager_id])) + session = create_session() p = Person(name='some person') m = Manager(name='some manager') @@ -75,7 +76,10 @@ class RelationTest1(ORMTest): mapper(Person, people) mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, properties={ - 'employee':relation(Person, primaryjoin=people.c.manager_id==managers.c.person_id, foreignkey=people.c.manager_id, uselist=False, post_update=True) + 'employee':relation(Person, primaryjoin=(people.c.manager_id == + managers.c.person_id), + foreign_keys=[people.c.manager_id], + uselist=False, post_update=True) }) session = create_session() @@ -90,27 +94,27 @@ class RelationTest1(ORMTest): m = session.query(Manager).get(m.person_id) print p, m, m.employee assert m.employee is p - + class RelationTest2(ORMTest): """test self-referential relationships on polymorphic mappers""" def define_tables(self, metadata): global people, managers, data - people = Table('people', metadata, + people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), Column('name', String(50)), Column('type', String(30))) - managers = Table('managers', metadata, + managers = Table('managers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('manager_id', Integer, ForeignKey('people.person_id')), Column('status', String(30)), ) - + data = Table('data', metadata, Column('person_id', Integer, ForeignKey('managers.person_id'), primary_key=True), Column('data', String(30)) ) - + def testrelationonsubclass_j1_nodata(self): self.do_test("join1", False) def testrelationonsubclass_j2_nodata(self): @@ -123,7 +127,7 @@ class RelationTest2(ORMTest): self.do_test("join3", False) def testrelationonsubclass_j3_data(self): self.do_test("join3", True) - + def do_test(self, jointype="join1", usedata=False): class Person(AttrSettable): pass @@ -145,13 +149,13 @@ class RelationTest2(ORMTest): elif jointype == "join3": poly_union = None polymorphic_on = people.c.type - + if usedata: class Data(object): def __init__(self, data): self.data = data mapper(Data, data) - + mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=polymorphic_on) if usedata: @@ -176,7 +180,7 @@ class RelationTest2(ORMTest): m.data = Data('ms data') sess.save(m) sess.flush() - + sess.clear() p = sess.query(Person).get(p.person_id) m = sess.query(Manager).get(m.person_id) @@ -190,13 +194,13 @@ class RelationTest3(ORMTest): """test self-referential relationships on polymorphic mappers""" def define_tables(self, metadata): global people, managers, data - people = Table('people', metadata, + people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), Column('colleague_id', Integer, ForeignKey('people.person_id')), Column('name', String(50)), Column('type', 
String(30))) - managers = Table('managers', metadata, + managers = Table('managers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('status', String(30)), ) @@ -232,26 +236,27 @@ def generate_test(jointype="join1", usedata=False): poly_union = people.outerjoin(managers) elif jointype == "join4": poly_union=None - + if usedata: mapper(Data, data) - - mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager') + if usedata: mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type, properties={ 'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, remote_side=people.c.colleague_id, uselist=True), 'data':relation(Data, uselist=False) - } + } ) else: mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type, properties={ - 'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, + 'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, remote_side=people.c.colleague_id, uselist=True) - } + } ) + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager') + sess = create_session() p = Person(name='person1') p2 = Person(name='person2') @@ -266,7 +271,7 @@ def generate_test(jointype="join1", usedata=False): sess.save(m) sess.save(p) sess.flush() - + sess.clear() p = sess.query(Person).get(p.person_id) p2 = sess.query(Person).get(p2.person_id) @@ -279,40 +284,43 @@ def generate_test(jointype="join1", usedata=False): if usedata: assert p.data.data == 'ps data' assert m.data.data == 'ms data' - - do_test.__name__ = 'test_relationonbaseclass_%s_%s' % (jointype, data and "nodata" or "data") + + do_test = _function_named( + do_test, 'test_relationonbaseclass_%s_%s' % ( + jointype, data and "nodata" or "data")) return do_test for jointype in ["join1", "join2", "join3", "join4"]: for data in (True, False): func = generate_test(jointype, data) setattr(RelationTest3, func.__name__, func) - - + + class RelationTest4(ORMTest): def define_tables(self, metadata): global people, engineers, managers, cars - people = Table('people', metadata, + people = Table('people', metadata, Column('person_id', Integer, primary_key=True), Column('name', String(50))) - engineers = Table('engineers', metadata, + engineers = Table('engineers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('status', String(30))) - managers = Table('managers', metadata, + managers = Table('managers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('longer_status', String(70))) - cars = Table('cars', metadata, + cars = Table('cars', metadata, Column('car_id', Integer, primary_key=True), Column('owner', Integer, ForeignKey('people.person_id'))) - + def testmanytoonepolymorphic(self): """in this test, the polymorphic union is between two subclasses, but does not include the base table by itself in the union. however, the primaryjoin condition is going to be against the base table, and its a many-to-one relationship (unlike the test in polymorph.py) so the column in the base table is explicit. 
Can the ClauseAdapter figure out how to alias the primaryjoin to the polymorphic union ?""" + # class definitions class Person(object): def __init__(self, **kwargs): @@ -333,30 +341,17 @@ class RelationTest4(ORMTest): def __repr__(self): return "Car number %d" % self.car_id - # create a union that represents both types of joins. + # create a union that represents both types of joins. employee_join = polymorphic_union( { 'engineer':people.join(engineers), 'manager':people.join(managers), }, "type", 'employee_join') - + person_mapper = mapper(Person, people, select_table=employee_join,polymorphic_on=employee_join.c.type, polymorphic_identity='person') engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') car_mapper = mapper(Car, cars, properties= {'employee':relation(person_mapper)}) - - print class_mapper(Person).primary_key - print person_mapper.get_select_mapper().primary_key - - # so the primaryjoin is "people.c.person_id==cars.c.owner". the "lazy" clause will be - # "people.c.person_id=?". the employee_join is two selects union'ed together, one of which - # will contain employee.c.person_id the other contains manager.c.person_id. people.c.person_id is not explicitly in - # either column clause in this case. we can modify polymorphic_union to always put the "base" column in which would fix this, - # but im not sure if that really fixes the issue in all cases and its too far from the problem. - # instead, when the primaryjoin is adapted to point to the polymorphic union and is targeting employee_join.c.person_id, - # it has to use not just straight column correspondence but also "keys_ok=True", meaning it will link up to any column - # with the name "person_id", as opposed to columns that descend directly from people.c.person_id. polymorphic unions - # require the cols all match up on column name which then determine the top selectable names, so matching by name is OK. session = create_session() @@ -371,7 +366,7 @@ class RelationTest4(ORMTest): engineer4 = session.query(Engineer).filter(Engineer.name=="E4").first() manager3 = session.query(Manager).filter(Manager.name=="M3").first() - + car1 = Car(employee=engineer4) session.save(car1) car2 = Car(employee=manager3) @@ -379,7 +374,12 @@ class RelationTest4(ORMTest): session.flush() session.clear() - + + def go(): + testcar = session.query(Car).options(eagerload('employee')).get(car1.car_id) + assert str(testcar.employee) == "Engineer E4, status X" + self.assert_sql_count(testing.db, go, 1) + print "----------------------------" car1 = session.query(Car).get(car1.car_id) print "----------------------------" @@ -397,8 +397,11 @@ class RelationTest4(ORMTest): session.clear() print "-----------------------------------------------------------------" # and now for the lightning round, eager ! 
- car1 = session.query(Car).options(eagerload('employee')).get(car1.car_id) - assert str(car1.employee) == "Engineer E4, status X" + + def go(): + testcar = session.query(Car).options(eagerload('employee')).get(car1.car_id) + assert str(testcar.employee) == "Engineer E4, status X" + self.assert_sql_count(testing.db, go, 1) session.clear() s = session.query(Car) @@ -408,23 +411,23 @@ class RelationTest4(ORMTest): class RelationTest5(ORMTest): def define_tables(self, metadata): global people, engineers, managers, cars - people = Table('people', metadata, + people = Table('people', metadata, Column('person_id', Integer, primary_key=True), Column('name', String(50)), Column('type', String(50))) - engineers = Table('engineers', metadata, + engineers = Table('engineers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('status', String(30))) - managers = Table('managers', metadata, + managers = Table('managers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('longer_status', String(70))) - cars = Table('cars', metadata, + cars = Table('cars', metadata, Column('car_id', Integer, primary_key=True), Column('owner', Integer, ForeignKey('people.person_id'))) - + def testeagerempty(self): """an easy one...test parent object with child relation to an inheriting mapper, using eager loads, works when there are no child objects present""" @@ -460,8 +463,8 @@ class RelationTest5(ORMTest): sess.save(car2) sess.flush() sess.clear() - - carlist = sess.query(Car).select() + + carlist = sess.query(Car).all() assert carlist[0].manager is None assert carlist[1].manager.person_id == car2.manager.person_id @@ -469,12 +472,12 @@ class RelationTest6(ORMTest): """test self-referential relationships on a single joined-table inheritance mapper""" def define_tables(self, metadata): global people, managers, data - people = Table('people', metadata, + people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), Column('name', String(50)), ) - managers = Table('managers', metadata, + managers = Table('managers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('colleague_id', Integer, ForeignKey('managers.person_id')), Column('status', String(30)), @@ -566,7 +569,7 @@ class RelationTest7(ORMTest): employee_join = polymorphic_union( { 'engineer':people.join(engineers), - 'manager':people.join(managers), + 'manager':people.join(managers), }, "type", 'employee_join') car_join = polymorphic_union( @@ -582,7 +585,7 @@ class RelationTest7(ORMTest): offroad_car_mapper = mapper(Offraod_Car, offroad_cars, inherits=car_mapper, polymorphic_identity='offroad') person_mapper = mapper(Person, people, select_table=employee_join,polymorphic_on=employee_join.c.type, - polymorphic_identity='person', + polymorphic_identity='person', properties={ 'car':relation(car_mapper) }) @@ -603,11 +606,13 @@ class RelationTest7(ORMTest): session.flush() session.clear() - r = session.query(Person).select() + r = session.query(Person).all() for p in r: assert p.car_id == p.car.car_id - -class GenerativeTest(AssertMixin): + + + +class GenerativeTest(TestBase, AssertsExecutionResults): def setUpAll(self): # cars---owned by--- people (abstract) --- has a --- status # | ^ ^ | @@ -617,26 +622,26 @@ class GenerativeTest(AssertMixin): # +--------------------------------------- has a ------+ global metadata, status, people, engineers, managers, cars - 
metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) # table definitions - status = Table('status', metadata, + status = Table('status', metadata, Column('status_id', Integer, primary_key=True), Column('name', String(20))) - people = Table('people', metadata, + people = Table('people', metadata, Column('person_id', Integer, primary_key=True), Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), Column('name', String(50))) - engineers = Table('engineers', metadata, + engineers = Table('engineers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('field', String(30))) - managers = Table('managers', metadata, + managers = Table('managers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), Column('category', String(70))) - cars = Table('cars', metadata, + cars = Table('cars', metadata, Column('car_id', Integer, primary_key=True), Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), Column('owner', Integer, ForeignKey('people.person_id'), nullable=False)) @@ -649,7 +654,7 @@ class GenerativeTest(AssertMixin): clear_mappers() for t in metadata.table_iterator(reverse=True): t.delete().execute() - + def testjointo(self): # class definitions class PersistentObject(object): @@ -672,7 +677,7 @@ class GenerativeTest(AssertMixin): def __repr__(self): return "Car number %d" % self.car_id - # create a union that represents both types of joins. + # create a union that represents both types of joins. employee_join = polymorphic_union( { 'engineer':people.join(engineers), @@ -680,8 +685,8 @@ class GenerativeTest(AssertMixin): }, "type", 'employee_join') status_mapper = mapper(Status, status) - person_mapper = mapper(Person, people, - select_table=employee_join,polymorphic_on=employee_join.c.type, + person_mapper = mapper(Person, people, + select_table=employee_join,polymorphic_on=employee_join.c.type, polymorphic_identity='person', properties={'status':relation(status_mapper)}) engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') @@ -697,7 +702,7 @@ class GenerativeTest(AssertMixin): session.flush() # TODO: we haven't created assertions for all the data combinations created here - + # creating 5 managers named from M1 to M5 and 5 engineers named from E1 to E5 # M4, M5, E4 and E5 are dead for i in range(1,5): @@ -711,7 +716,7 @@ class GenerativeTest(AssertMixin): session.flush() # get E4 - engineer4 = session.query(engineer_mapper).get_by(name="E4") + engineer4 = session.query(engineer_mapper).filter_by(name="E4").one() # create 2 cars for E4, one active and one dead car1 = Car(employee=engineer4,status=active) @@ -724,15 +729,15 @@ class GenerativeTest(AssertMixin): for x in range(0, 2): r = session.query(Person).filter(people.c.name.like('%2')).join('status').filter_by(name="active") assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]" - r = session.query(Engineer).join('status').filter(people.c.name.in_('E2', 'E3', 'E4', 'M4', 'M2', 'M1') & (status.c.name=="active")) + r = session.query(Engineer).join('status').filter(people.c.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active")) assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]" - # this test embeds the 
original polymorphic union (employee_join) fully - # into the WHERE criterion, using a correlated select. ticket #577 tracks - # that Query's adaptation of the WHERE clause does not dig into the + # this test embeds the original polymorphic union (employee_join) fully + # into the WHERE criterion, using a correlated select. ticket #577 tracks + # that Query's adaptation of the WHERE clause does not dig into the # mapped selectable itself, which permanently breaks the mapped selectable. r = session.query(Person).filter(exists([Car.c.owner], Car.c.owner==employee_join.c.person_id)) assert str(list(r)) == "[Engineer E4, field X, status Status dead]" - + class MultiLevelTest(ORMTest): def define_tables(self, metadata): global table_Employee, table_Engineer, table_Manager @@ -772,7 +777,7 @@ class MultiLevelTest(ORMTest): # 'Engineer': table_Employee.join(table_Engineer).select(table_Employee.c.atype == 'Engineer'), # 'Employee': table_Employee.select( table_Employee.c.atype == 'Employee'), # }, None, 'pu_employee', ) - + mapper_Employee = mapper( Employee, table_Employee, polymorphic_identity= 'Employee', polymorphic_on= pu_Employee.c.atype, @@ -806,9 +811,9 @@ class MultiLevelTest(ORMTest): session.save(b) session.save(c) session.flush() - assert set(session.query(Employee).select()) == set([a,b,c]) - assert set(session.query( Engineer).select()) == set([b,c]) - assert session.query( Manager).select() == [c] + assert set(session.query(Employee).all()) == set([a,b,c]) + assert set(session.query( Engineer).all()) == set([b,c]) + assert session.query( Manager).all() == [c] class ManyToManyPolyTest(ORMTest): def define_tables(self, metadata): @@ -832,9 +837,9 @@ class ManyToManyPolyTest(ORMTest): 'collection', metadata, Column('id', Integer, primary_key=True), Column('name', Unicode(255))) - + def test_pjoin_compile(self): - """test that remote_side columns in the secondary join table arent attempted to be + """test that remote_side columns in the secondary join table arent attempted to be matched to the target polymorphic selectable""" class BaseItem(object): pass class Item(BaseItem): pass @@ -857,13 +862,13 @@ class ManyToManyPolyTest(ORMTest): polymorphic_identity='Item') mapper(Collection, collection_table) - + class_mapper(BaseItem) class CustomPKTest(ORMTest): def define_tables(self, metadata): global t1, t2 - t1 = Table('t1', metadata, + t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('type', String(30), nullable=False), Column('data', String(30))) @@ -874,23 +879,19 @@ class CustomPKTest(ORMTest): def test_custompk(self): """test that the primary_key attribute is propigated to the polymorphic mapper""" - + class T1(object):pass class T2(T1):pass - + # create a polymorphic union with the select against the base table first. - # with the join being second, the alias of the union will + # with the join being second, the alias of the union will # pick up two "primary key" columns. 
technically the alias should have a # 2-col pk in any case but the leading select has a NULL for the "t2id" column d = util.OrderedDict() d['t1'] = t1.select(t1.c.type=='t1') d['t2'] = t1.join(t2) pjoin = polymorphic_union(d, None, 'pjoin') - - #print pjoin.original.primary_key - #print pjoin.primary_key - assert len(pjoin.primary_key) == 2 - + mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', select_table=pjoin, primary_key=[pjoin.c.id]) mapper(T2, t2, inherits=T1, polymorphic_identity='t2') print [str(c) for c in class_mapper(T1).primary_key] @@ -901,24 +902,24 @@ class CustomPKTest(ORMTest): sess.save(ot2) sess.flush() sess.clear() - + # query using get(), using only one value. this requires the select_table mapper # has the same single-col primary key. assert sess.query(T1).get(ot1.id).id == ot1.id - + ot1 = sess.query(T1).get(ot1.id) ot1.data = 'hi' sess.flush() def test_pk_collapses(self): - """test that a composite primary key attribute formed by a join is "collapsed" into its + """test that a composite primary key attribute formed by a join is "collapsed" into its minimal columns""" class T1(object):pass class T2(T1):pass # create a polymorphic union with the select against the base table first. - # with the join being second, the alias of the union will + # with the join being second, the alias of the union will # pick up two "primary key" columns. technically the alias should have a # 2-col pk in any case but the leading select has a NULL for the "t2id" column d = util.OrderedDict() @@ -926,15 +927,10 @@ class CustomPKTest(ORMTest): d['t2'] = t1.join(t2) pjoin = polymorphic_union(d, None, 'pjoin') - #print pjoin.original.primary_key - #print pjoin.primary_key - assert len(pjoin.primary_key) == 2 - mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', select_table=pjoin) mapper(T2, t2, inherits=T1, polymorphic_identity='t2') assert len(class_mapper(T1).primary_key) == 1 - assert len(class_mapper(T1).get_select_mapper().compile().primary_key) == 1 - + print [str(c) for c in class_mapper(T1).primary_key] ot1 = T1() ot2 = T2() @@ -951,7 +947,116 @@ class CustomPKTest(ORMTest): ot1 = sess.query(T1).get(ot1.id) ot1.data = 'hi' sess.flush() + +class InheritingEagerTest(ORMTest): + def define_tables(self, metadata): + global people, employees, tags, peopleTags + + people = Table('people', metadata, + Column('id', Integer, primary_key=True), + Column('_type', String(30), nullable=False), + ) + + + employees = Table('employees', metadata, + Column('id', Integer, ForeignKey('people.id'),primary_key=True), + ) + + tags = Table('tags', metadata, + Column('id', Integer, primary_key=True), + Column('label', String(50), nullable=False), + ) + + peopleTags = Table('peopleTags', metadata, + Column('person_id', Integer,ForeignKey('people.id')), + Column('tag_id', Integer,ForeignKey('tags.id')), + ) + + def test_basic(self): + """test that Query uses the full set of mapper._eager_loaders when generating SQL""" + + class Person(fixtures.Base): + pass + + class Employee(Person): + def __init__(self, name='bob'): + self.name = name + + class Tag(fixtures.Base): + def __init__(self, label): + self.label = label + + mapper(Person, people, polymorphic_on=people.c._type,polymorphic_identity='person', properties={ + 'tags': relation(Tag, secondary=peopleTags,backref='people', lazy=False) + }) + mapper(Employee, employees, inherits=Person,polymorphic_identity='employee') + mapper(Tag, tags) + + session = create_session() + + bob = Employee() + session.save(bob) + + tag = 
Tag('crazy') + bob.tags.append(tag) + + tag = Tag('funny') + bob.tags.append(tag) + session.flush() + + session.clear() + # query from Employee with limit, query needs to apply eager limiting subquery + instance = session.query(Employee).filter_by(id=1).limit(1).first() + assert len(instance.tags) == 2 + +class MissingPolymorphicOnTest(ORMTest): + def define_tables(self, metadata): + global tablea, tableb, tablec, tabled + tablea = Table('tablea', metadata, + Column('id', Integer, primary_key=True), + Column('adata', String(50)), + ) + tableb = Table('tableb', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('tablea.id')), + Column('data', String(50)), + ) + tablec = Table('tablec', metadata, + Column('id', Integer, ForeignKey('tablea.id'), primary_key=True), + Column('cdata', String(50)), + ) + tabled = Table('tabled', metadata, + Column('id', Integer, ForeignKey('tablec.id'), primary_key=True), + Column('ddata', String(50)), + ) + + def test_polyon_col_setsup(self): + class A(fixtures.Base): + pass + class B(fixtures.Base): + pass + class C(A): + pass + class D(C): + pass + + poly_select = select([tablea, tableb.c.data.label('discriminator')], from_obj=tablea.join(tableb)).alias('poly') -if __name__ == "__main__": - testbase.main() + mapper(B, tableb) + mapper(A, tablea, select_table=poly_select, polymorphic_on=poly_select.c.discriminator, properties={ + 'b':relation(B, uselist=False) + }) + mapper(C, tablec, inherits=A,polymorphic_identity='c') + mapper(D, tabled, inherits=C, polymorphic_identity='d') + + c = C(cdata='c1', adata='a1', b=B(data='c')) + d = D(cdata='c2', adata='a2', ddata='d2', b=B(data='d')) + sess = create_session() + sess.save(c) + sess.save(d) + sess.flush() + sess.clear() + self.assertEquals(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')]) +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/inheritance/productspec.py b/test/orm/inheritance/productspec.py index 2459cd36e1..54810c31fa 100644 --- a/test/orm/inheritance/productspec.py +++ b/test/orm/inheritance/productspec.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() from datetime import datetime from sqlalchemy import * from sqlalchemy.orm import * @@ -26,7 +26,7 @@ class InheritTest(ORMTest): nullable=True), Column('quantity', Float, default=1.), ) - + documents_table = Table('documents', metadata, Column('document_id', Integer, primary_key=True), Column('document_type', String(128)), @@ -38,7 +38,7 @@ class InheritTest(ORMTest): Column('data', Binary), Column('size', Integer, default=0), ) - + class Product(object): def __init__(self, name, mark=''): self.name = name @@ -73,8 +73,8 @@ class InheritTest(ORMTest): self.data = data def __repr__(self): return '<%s %s>' % (self.__class__.__name__, self.name) - - class RasterDocument(Document): + + class RasterDocument(Document): pass def testone(self): @@ -91,12 +91,12 @@ class InheritTest(ORMTest): specification_mapper = mapper(SpecLine, specification_table, properties=dict( master=relation(Assembly, - foreignkey=specification_table.c.master_id, + foreign_keys=[specification_table.c.master_id], primaryjoin=specification_table.c.master_id==products_table.c.product_id, lazy=True, backref=backref('specification', primaryjoin=specification_table.c.master_id==products_table.c.product_id), uselist=False), - slave=relation(Product, - foreignkey=specification_table.c.slave_id, + slave=relation(Product, + 
foreign_keys=[specification_table.c.slave_id], primaryjoin=specification_table.c.slave_id==products_table.c.product_id, lazy=True, uselist=False), quantity=specification_table.c.quantity, @@ -118,7 +118,7 @@ class InheritTest(ORMTest): session.flush() session.clear() - a1 = session.query(Product).get_by(name='a1') + a1 = session.query(Product).filter_by(name='a1').one() new = repr(a1) print orig print new @@ -134,8 +134,8 @@ class InheritTest(ORMTest): specification_mapper = mapper(SpecLine, specification_table, properties=dict( - slave=relation(Product, - foreignkey=specification_table.c.slave_id, + slave=relation(Product, + foreign_keys=[specification_table.c.slave_id], primaryjoin=specification_table.c.slave_id==products_table.c.product_id, lazy=True, uselist=False), ) @@ -150,7 +150,7 @@ class InheritTest(ORMTest): orig = repr([s, s2]) session.flush() session.clear() - new = repr(session.query(SpecLine).select()) + new = repr(session.query(SpecLine).all()) print orig print new assert orig == new == '[>, >]' @@ -167,12 +167,12 @@ class InheritTest(ORMTest): specification_mapper = mapper(SpecLine, specification_table, properties=dict( master=relation(Assembly, lazy=False, uselist=False, - foreignkey=specification_table.c.master_id, + foreign_keys=[specification_table.c.master_id], primaryjoin=specification_table.c.master_id==products_table.c.product_id, backref=backref('specification', primaryjoin=specification_table.c.master_id==products_table.c.product_id, cascade="all, delete-orphan"), ), slave=relation(Product, lazy=False, uselist=False, - foreignkey=specification_table.c.slave_id, + foreign_keys=[specification_table.c.slave_id], primaryjoin=specification_table.c.slave_id==products_table.c.product_id, ), quantity=specification_table.c.quantity, @@ -202,7 +202,7 @@ class InheritTest(ORMTest): session.flush() session.clear() - a1 = session.query(Product).get_by(name='a1') + a1 = session.query(Product).filter_by(name='a1').one() new = repr(a1) print orig print new @@ -240,19 +240,18 @@ class InheritTest(ORMTest): session.flush() session.clear() - a1 = session.query(Product).get_by(name='a1') + a1 = session.query(Product).filter_by(name='a1').one() new = repr(a1) print orig print new assert orig == new == ' specification=None documents=[]' del a1.documents[0] - session.save(a1) session.flush() session.clear() - a1 = session.query(Product).get_by(name='a1') - assert len(session.query(Document).select()) == 0 + a1 = session.query(Product).filter_by(name='a1').one() + assert len(session.query(Document).all()) == 0 def testfive(self): """tests the late compilation of mappers""" @@ -260,24 +259,18 @@ class InheritTest(ORMTest): specification_mapper = mapper(SpecLine, specification_table, properties=dict( master=relation(Assembly, lazy=False, uselist=False, - foreignkey=specification_table.c.master_id, + foreign_keys=[specification_table.c.master_id], primaryjoin=specification_table.c.master_id==products_table.c.product_id, backref=backref('specification', primaryjoin=specification_table.c.master_id==products_table.c.product_id), ), slave=relation(Product, lazy=False, uselist=False, - foreignkey=specification_table.c.slave_id, + foreign_keys=[specification_table.c.slave_id], primaryjoin=specification_table.c.slave_id==products_table.c.product_id, ), quantity=specification_table.c.quantity, ) ) - detail_mapper = mapper(Detail, inherits=Product, - polymorphic_identity='detail') - - raster_document_mapper = mapper(RasterDocument, inherits=Document, - polymorphic_identity='raster_document') - 
product_mapper = mapper(Product, products_table, polymorphic_on=products_table.c.product_type, polymorphic_identity='product', properties={ @@ -285,8 +278,8 @@ class InheritTest(ORMTest): backref='product', cascade='all, delete-orphan'), }) - assembly_mapper = mapper(Assembly, inherits=Product, - polymorphic_identity='assembly') + detail_mapper = mapper(Detail, inherits=Product, + polymorphic_identity='detail') document_mapper = mapper(Document, documents_table, polymorphic_on=documents_table.c.document_type, @@ -297,6 +290,12 @@ class InheritTest(ORMTest): ), ) + raster_document_mapper = mapper(RasterDocument, inherits=Document, + polymorphic_identity='raster_document') + + assembly_mapper = mapper(Assembly, inherits=Product, + polymorphic_identity='assembly') + session = create_session() a1 = Assembly(name='a1') @@ -308,11 +307,11 @@ class InheritTest(ORMTest): session.flush() session.clear() - a1 = session.query(Product).get_by(name='a1') + a1 = session.query(Product).filter_by(name='a1').one() new = repr(a1) print orig print new assert orig == new == ' specification=[>] documents=[, ]' - -if __name__ == "__main__": - testbase.main() + +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py new file mode 100644 index 0000000000..34ead1622c --- /dev/null +++ b/test/orm/inheritance/query.py @@ -0,0 +1,575 @@ +"""tests the Query object's ability to work with polymorphic selectables +and inheriting mappers.""" + +# TODO: under construction ! + +import testenv; testenv.configure_for_tests() +import sets +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy import exceptions +from testlib import * +from testlib import fixtures + +class Company(fixtures.Base): + pass + +class Person(fixtures.Base): + pass +class Engineer(Person): + pass +class Manager(Person): + pass +class Boss(Manager): + pass + +class Machine(fixtures.Base): + pass + +class Paperwork(fixtures.Base): + pass + +def make_test(select_type): + class PolymorphicQueryTest(ORMTest): + keep_data = True + keep_mappers = True + + def define_tables(self, metadata): + global companies, people, engineers, managers, boss, paperwork, machines + + companies = Table('companies', metadata, + Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key=True), + Column('name', String(50))) + + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('company_id', Integer, ForeignKey('companies.company_id')), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('engineer_name', String(50)), + Column('primary_language', String(50)), + ) + + machines = Table('machines', metadata, + Column('machine_id', Integer, primary_key=True), + Column('name', String(50)), + Column('engineer_id', Integer, ForeignKey('engineers.person_id'))) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('manager_name', String(50)) + ) + + boss = Table('boss', metadata, + Column('boss_id', Integer, ForeignKey('managers.person_id'), primary_key=True), + Column('golf_swing', String(30)), + ) + + paperwork = Table('paperwork', metadata, + Column('paperwork_id', Integer, primary_key=True), + 
Column('description', String(50)), + Column('person_id', Integer, ForeignKey('people.person_id'))) + + clear_mappers() + + mapper(Company, companies, properties={ + 'employees':relation(Person, order_by=people.c.person_id) + }) + + mapper(Machine, machines) + + if select_type == '': + person_join = manager_join = None + person_with_polymorphic = None + manager_with_polymorphic = None + elif select_type == 'Polymorphic': + person_join = manager_join = None + person_with_polymorphic = '*' + manager_with_polymorphic = '*' + elif select_type == 'Unions': + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, None, 'pjoin') + + manager_join = people.join(managers).outerjoin(boss) + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ('*', manager_join) + elif select_type == 'AliasedJoins': + person_join = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin') + manager_join = people.join(managers).outerjoin(boss).select(use_labels=True).alias('mjoin') + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ('*', manager_join) + elif select_type == 'Joins': + person_join = people.outerjoin(engineers).outerjoin(managers) + manager_join = people.join(managers).outerjoin(boss) + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ('*', manager_join) + + + # testing a order_by here as well; the surrogate mapper has to adapt it + mapper(Person, people, + with_polymorphic=person_with_polymorphic, + polymorphic_on=people.c.type, polymorphic_identity='person', order_by=people.c.person_id, + properties={ + 'paperwork':relation(Paperwork) + }) + mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer', properties={ + 'machines':relation(Machine) + }) + mapper(Manager, managers, with_polymorphic=manager_with_polymorphic, + inherits=Person, polymorphic_identity='manager') + mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss') + mapper(Paperwork, paperwork) + + + def insert_data(self): + global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2 + + c1 = Company(name="MegaCorp, Inc.") + c2 = Company(name="Elbonia, Inc.") + e1 = Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", paperwork=[ + Paperwork(description="tps report #1"), + Paperwork(description="tps report #2") + ], machines=[ + Machine(name='IBM ThinkPad'), + Machine(name='IPhone'), + ]) + e2 = Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer", paperwork=[ + Paperwork(description="tps report #3"), + Paperwork(description="tps report #4") + ], machines=[ + Machine(name="Commodore 64") + ]) + b1 = Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss", paperwork=[ + Paperwork(description="review #1"), + ]) + m1 = Manager(name="dogbert", manager_name="dogbert", status="regular manager", paperwork=[ + Paperwork(description="review #2"), + Paperwork(description="review #3") + ]) + c1.employees = [e1, e2, b1, m1] + + e3 = Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer", paperwork=[ + Paperwork(description='elbonian missive #3') + ], machines=[ + Machine(name="Commodore 64"), + Machine(name="IBM 3270") + ]) + + c2.employees = [e3] + sess = create_session() + sess.save(c1) + sess.save(c2) + 
sess.flush() + sess.clear() + + all_employees = [e1, e2, b1, m1, e3] + c1_employees = [e1, e2, b1, m1] + c2_employees = [e3] + + def test_loads_at_once(self): + """test that all objects load from the full query, when with_polymorphic is used""" + + sess = create_session() + def go(): + self.assertEquals(sess.query(Person).all(), all_employees) + self.assert_sql_count(testing.db, go, {'':14, 'Polymorphic':9}.get(select_type, 10)) + + def test_primary_eager_aliasing(self): + sess = create_session() + def go(): + self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3].all(), all_employees[1:3]) + self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4)) + + sess = create_session() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3].all(), all_employees[1:3]) + self.assert_sql_count(testing.db, go, 3) + + + def test_get(self): + sess = create_session() + + # for all mappers, ensure the primary key has been calculated as just the "person_id" + # column + self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert")) + self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert")) + self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss")) + + def test_filter_on_subclass(self): + sess = create_session() + self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert")) + + self.assertEquals(sess.query(Engineer).first(), Engineer(name="dilbert")) + + self.assertEquals(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert")) + + self.assertEquals(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert")) + + self.assertEquals(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + + self.assertEquals(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + + def test_join_from_polymorphic(self): + sess = create_session() + + for aliased in (True, False): + self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + + self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + + self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) + + self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + + def test_join_from_with_polymorphic(self): + sess = create_session() + + for aliased in (True, False): + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + + def test_join_to_polymorphic(self): + sess = create_session() + 
self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) + + self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) + + def test_polymorphic_any(self): + sess = create_session() + + self.assertEquals( + sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), + c2 + ) + + self.assertEquals( + sess.query(Company).filter(Company.employees.of_type(Boss).any(Boss.golf_swing=='fore')).one(), + c1 + ) + self.assertEquals( + sess.query(Company).filter(Company.employees.of_type(Boss).any(Manager.manager_name=='pointy')).one(), + c1 + ) + + if select_type != '': + self.assertEquals( + sess.query(Person).filter(Engineer.machines.any(Machine.name=="Commodore 64")).all(), [e2, e3] + ) + + self.assertEquals( + sess.query(Person).filter(Person.paperwork.any(Paperwork.description=="review #2")).all(), [m1] + ) + + self.assertEquals( + sess.query(Company).filter(Company.employees.of_type(Engineer).any(and_(Engineer.primary_language=='cobol'))).one(), + c2 + ) + + + def test_expire(self): + """test that individual column refresh doesn't get tripped up by the select_table mapper""" + + sess = create_session() + m1 = sess.query(Manager).filter(Manager.name=='dogbert').one() + sess.expire(m1) + assert m1.status == 'regular manager' + + m2 = sess.query(Manager).filter(Manager.name=='pointy haired boss').one() + sess.expire(m2, ['manager_name', 'golf_swing']) + assert m2.golf_swing=='fore' + + def test_with_polymorphic(self): + + sess = create_session() + + # compare to entities without related collections to prevent additional lazy SQL from firing on + # loaded entities + emps_without_relations = [ + Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer"), + Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"), + Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"), + Manager(name="dogbert", manager_name="dogbert", status="regular manager"), + Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") + ] + + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 3) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 3) + + sess.clear() + def go(): + # limit the polymorphic join down to just "Person", overriding select_table + self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 6) + + def test_relation_to_polymorphic(self): + assert_result = [ + Company(name="MegaCorp, Inc.", employees=[ + Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")]), + Engineer(name="wally", engineer_name="wally", 
primary_language="c++", status="regular engineer"), + Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"), + Manager(name="dogbert", manager_name="dogbert", status="regular manager"), + ]), + Company(name="Elbonia, Inc.", employees=[ + Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") + ]) + ] + + sess = create_session() + def go(): + # test load Companies with lazy load to 'employees' + self.assertEquals(sess.query(Company).all(), assert_result) + self.assert_sql_count(testing.db, go, {'':9, 'Polymorphic':4}.get(select_type, 5)) + + sess = create_session() + def go(): + # currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer). eagerloader doesn't + # pick up on the "of_type()" as of yet. + self.assertEquals(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result) + + # in the case of select_type='', the eagerload doesn't take in this case; + # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines" + self.assert_sql_count(testing.db, go, {'':7, 'Polymorphic':1}.get(select_type, 2)) + + def test_eagerload_on_subclass(self): + sess = create_session() + def go(): + # test load People with eagerload to engineers + machines + self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), + [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])] + ) + self.assert_sql_count(testing.db, go, 1) + + def test_join_to_subclass(self): + sess = create_session() + + if select_type == '': + self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) + self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) + self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3]) + self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2]) + self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + else: + self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) + self.assertEquals(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1]) + self.assertEquals(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) + self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2]) + self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + + # non-polymorphic + self.assertEquals(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) + 
self.assertEquals(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + + # here's the new way + self.assertEquals(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1]) + self.assertEquals(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + + def test_join_through_polymorphic(self): + + sess = create_session() + + for aliased in (True, False): + self.assertEquals( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + self.assertEquals( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + + self.assertEquals( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + self.assertEquals( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + + self.assertEquals( + sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ + join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + self.assertEquals( + sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ + join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + + def test_filter_on_baseclass(self): + sess = create_session() + + self.assertEquals(sess.query(Person).all(), all_employees) + + self.assertEquals(sess.query(Person).first(), all_employees[0]) + + self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2) + + PolymorphicQueryTest.__name__ = "Polymorphic%sTest" % select_type + return PolymorphicQueryTest + +for select_type in ('', 'Polymorphic', 'Unions', 'AliasedJoins', 'Joins'): + testclass = make_test(select_type) + exec("%s = testclass" % testclass.__name__) + +del testclass + +class SelfReferentialTest(ORMTest): + keep_mappers = True + + def define_tables(self, metadata): + global people, engineers + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('primary_language', String(50)), + Column('reports_to_id', Integer, ForeignKey('people.person_id')) + ) + + mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') + mapper(Engineer, engineers, inherits=Person, + inherit_condition=engineers.c.person_id==people.c.person_id, + polymorphic_identity='engineer', properties={ + 'reports_to':relation(Person, primaryjoin=people.c.person_id==engineers.c.reports_to_id) + }) + + def test_has(self): + + p1 = Person(name='dogbert') + e1 = Engineer(name='dilbert', primary_language='java', reports_to=p1) + sess = create_session() + sess.save(p1) + sess.save(e1) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), 
Engineer(name='dilbert'))
+
+    def test_join(self):
+        p1 = Person(name='dogbert')
+        e1 = Engineer(name='dilbert', primary_language='java', reports_to=p1)
+        sess = create_session()
+        sess.save(p1)
+        sess.save(e1)
+        sess.flush()
+        sess.clear()
+
+        self.assertEquals(sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), Engineer(name='dilbert'))
+
+    def test_noalias_raises(self):
+        sess = create_session()
+        def go():
+            sess.query(Engineer).join('reports_to')
+        self.assertRaises(exceptions.InvalidRequestError, go)
+
+class M2MFilterTest(ORMTest):
+    keep_mappers = True
+    keep_data = True
+
+    def define_tables(self, metadata):
+        global people, engineers, Organization
+
+        organizations = Table('organizations', metadata,
+            Column('id', Integer, Sequence('org_id_seq', optional=True), primary_key=True),
+            Column('name', String(50)),
+            )
+        engineers_to_org = Table('engineers_org', metadata,
+            Column('org_id', Integer, ForeignKey('organizations.id')),
+            Column('engineer_id', Integer, ForeignKey('engineers.person_id')),
+            )
+
+        people = Table('people', metadata,
+            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+            Column('name', String(50)),
+            Column('type', String(30)))
+
+        engineers = Table('engineers', metadata,
+            Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+            Column('primary_language', String(50)),
+            )
+
+        class Organization(fixtures.Base):
+            pass
+
+        mapper(Organization, organizations, properties={
+            'engineers':relation(Engineer, secondary=engineers_to_org, backref='organizations')
+        })
+
+        mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
+        mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
+
+    def insert_data(self):
+        e1 = Engineer(name='e1')
+        e2 = Engineer(name='e2')
+        e3 = Engineer(name='e3')
+        e4 = Engineer(name='e4')
+        org1 = Organization(name='org1', engineers=[e1, e2])
+        org2 = Organization(name='org2', engineers=[e3, e4])
+
+        sess = create_session()
+        sess.save(org1)
+        sess.save(org2)
+        sess.flush()
+
+    def test_not_contains(self):
+        sess = create_session()
+
+        e1 = sess.query(Person).filter(Engineer.name=='e1').one()
+
+        # this works
+        self.assertEquals(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')])
+
+        # this had a bug
+        self.assertEquals(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')])
+
+    def test_any(self):
+        sess = create_session()
+        self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+        self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+
+if __name__ == "__main__":
+    testenv.main()
diff --git a/test/orm/inheritance/selects.py b/test/orm/inheritance/selects.py
new file mode 100644
index 0000000000..b3a343e387
--- /dev/null
+++ b/test/orm/inheritance/selects.py
@@ -0,0 +1,51 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import Base
+
+
+class InheritingSelectablesTest(ORMTest):
+    def define_tables(self, metadata):
+        global foo, bar, baz
+        foo = Table('foo', metadata,
+                    Column('a', String(30), primary_key=1),
+                    Column('b', String(30), nullable=0))
+
+        bar = foo.select(foo.c.b == 'bar').alias('bar')
+        baz = foo.select(foo.c.b == 'baz').alias('baz')
+
+    def test_load(self):
+        # TODO: add persistence test also
+        testing.db.execute(foo.insert(), a='not bar', b='baz')
+        testing.db.execute(foo.insert(), a='also not bar', b='baz')
+        testing.db.execute(foo.insert(), a='i am bar', b='bar')
+        testing.db.execute(foo.insert(), a='also bar', b='bar')
+
+        class Foo(Base): pass
+        class Bar(Foo): pass
+        class Baz(Foo): pass
+
+        mapper(Foo, foo, polymorphic_on=foo.c.b)
+
+        mapper(Baz, baz,
+            select_table=foo.join(baz, foo.c.b=='baz').alias('baz'),
+            inherits=Foo,
+            inherit_condition=(foo.c.a==baz.c.a),
+            inherit_foreign_keys=[baz.c.a],
+            polymorphic_identity='baz')
+
+        mapper(Bar, bar,
+            select_table=foo.join(bar, foo.c.b=='bar').alias('bar'),
+            inherits=Foo,
+            inherit_condition=(foo.c.a==bar.c.a),
+            inherit_foreign_keys=[bar.c.a],
+            polymorphic_identity='bar')
+
+        s = sessionmaker(bind=testing.db)()
+
+        assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all()
+        assert [Bar(), Bar()] == s.query(Bar).all()
+
+if __name__ == '__main__':
+    testenv.main()
diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/single.py
index 68fe821af0..81223cc02e 100644
--- a/test/orm/inheritance/single.py
+++ b/test/orm/inheritance/single.py
@@ -1,14 +1,14 @@
-import testbase
+import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from testlib import *
 
-class SingleInheritanceTest(AssertMixin):
+class SingleInheritanceTest(TestBase, AssertsExecutionResults):
     def setUpAll(self):
-        metadata = MetaData(testbase.db)
+        metadata = MetaData(testing.db)
         global employees_table
-        employees_table = Table('employees', metadata, 
+        employees_table = Table('employees', metadata,
            Column('employee_id', Integer, primary_key=True),
            Column('name', String(50)),
            Column('manager_data', String(50)),
@@ -57,10 +57,10 @@ class SingleInheritanceTest(AssertMixin):
         session.save(e2)
         session.flush()
 
-        assert session.query(Employee).select() == [m1, e1, e2]
-        assert session.query(Engineer).select() == [e1, e2]
-        assert session.query(Manager).select() == [m1]
-        assert session.query(JuniorEngineer).select() == [e2]
-
+        assert session.query(Employee).all() == [m1, e1, e2]
+        assert session.query(Engineer).all() == [e1, e2]
+        assert session.query(Manager).all() == [m1]
+        assert session.query(JuniorEngineer).all() == [e2]
+
 if __name__ == '__main__':
-    testbase.main()
+    testenv.main()
diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py
index 6684c62881..55d79fd32b 100644
--- a/test/orm/lazy_relations.py
+++ b/test/orm/lazy_relations.py
@@ -1,18 +1,18 @@
 """basic tests of lazy loaded attributes"""
 
-import testbase
+import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
+from sqlalchemy import exceptions
 from sqlalchemy.orm import *
 from testlib import *
-from fixtures import *
+from testlib.fixtures import *
 from query import QueryTest
+import datetime
 
-class LazyTest(QueryTest):
+class LazyTest(FixtureTest):
     keep_mappers = False
+    keep_data = True
 
-    def setup_mappers(self):
-        pass
-
     def test_basic(self):
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses), lazy=True)
@@ -21,10 +21,11 @@ class LazyTest(QueryTest):
         q = sess.query(User)
         assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all()
 
+    @testing.uses_deprecated('SessionContext')
     def test_bindstosession(self):
         """test that lazy loaders use the mapper's contextual session if the parent instance is not in a
session, and that an error is raised if no contextual session""" - + from sqlalchemy.ext.sessioncontext import SessionContext ctx = SessionContext(create_session) m = mapper(User, users, properties = dict( @@ -46,6 +47,7 @@ class LazyTest(QueryTest): u = q.filter(users.c.id == 7).first() sess.expunge(u) assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u + assert False except exceptions.InvalidRequestError, err: assert "not bound to a Session, and no contextual session" in str(err) @@ -57,21 +59,21 @@ class LazyTest(QueryTest): assert [ User(id=7, addresses=[ Address(id=1) - ]), + ]), User(id=8, addresses=[ Address(id=3, email_address='ed@bettyboop.com'), Address(id=4, email_address='ed@lala.com'), Address(id=2, email_address='ed@wood.com') - ]), + ]), User(id=9, addresses=[ Address(id=5) - ]), + ]), User(id=10, addresses=[]) ] == q.all() - + def test_orderby_secondary(self): """tests that a regular mapper select on a single table can order by a relation to a second table""" - + mapper(Address, addresses) mapper(User, users, properties = dict( @@ -84,13 +86,13 @@ class LazyTest(QueryTest): Address(id=2, email_address='ed@wood.com'), Address(id=3, email_address='ed@bettyboop.com'), Address(id=4, email_address='ed@lala.com'), - ]), + ]), User(id=9, addresses=[ Address(id=5) - ]), + ]), User(id=7, addresses=[ Address(id=1) - ]), + ]), ] == l def test_orderby_desc(self): @@ -103,15 +105,15 @@ class LazyTest(QueryTest): assert [ User(id=7, addresses=[ Address(id=1) - ]), + ]), User(id=8, addresses=[ Address(id=2, email_address='ed@wood.com'), Address(id=4, email_address='ed@lala.com'), Address(id=3, email_address='ed@bettyboop.com'), - ]), + ]), User(id=9, addresses=[ Address(id=5) - ]), + ]), User(id=10, addresses=[]) ] == sess.query(User).all() @@ -128,9 +130,10 @@ class LazyTest(QueryTest): assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True) assert not class_mapper(Address)._is_orphan(user.addresses[0]) + def test_limit(self): """test limit operations combined with lazy-load relationships.""" - + mapper(Item, items) mapper(Order, orders, properties={ 'items':relation(Item, secondary=order_items, lazy=True) @@ -143,10 +146,10 @@ class LazyTest(QueryTest): sess = create_session() q = sess.query(User) - if testbase.db.engine.name == 'mssql': + if testing.against('maxdb', 'mssql'): l = q.limit(2).all() assert fixtures.user_all_result[:2] == l - else: + else: l = q.limit(2).offset(1).all() assert fixtures.user_all_result[1:3] == l @@ -184,7 +187,7 @@ class LazyTest(QueryTest): closedorders = alias(orders, 'closedorders') mapper(Address, addresses) - + mapper(User, users, properties = dict( addresses = relation(Address, lazy = True), open_orders = relation(mapper(Order, openorders, entity_name='open'), primaryjoin = and_(openorders.c.isopen == 1, users.c.id==openorders.c.user_id), lazy=True), @@ -212,16 +215,21 @@ class LazyTest(QueryTest): closed_orders = [Order(id=2)] ), User(id=10) - + ] == q.all() + sess = create_session() + user = sess.query(User).get(7) + assert [Order(id=1), Order(id=5)] == create_session().query(Order, entity_name='closed').with_parent(user, property='closed_orders').all() + assert [Order(id=3)] == create_session().query(Order, entity_name='open').with_parent(user, property='open_orders').all() + def test_many_to_many(self): mapper(Keyword, keywords) mapper(Item, items, properties = dict( keywords = relation(Keyword, secondary=item_keywords, lazy=True), )) - + q = create_session().query(Item) assert 
fixtures.item_keyword_result == q.all() @@ -238,21 +246,21 @@ class LazyTest(QueryTest): mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy=True, primaryjoin=pj) )) - + sess = create_session() - + # load address a1 = sess.query(Address).filter_by(email_address="ed@wood.com").one() - + # load user that is attached to the address u1 = sess.query(User).get(8) - + def go(): # lazy load of a1.user should get it from the session assert a1.user is u1 - self.assert_sql_count(testbase.db, go, 0) + self.assert_sql_count(testing.db, go, 0) clear_mappers() - + def test_many_to_one(self): mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy=True) @@ -262,10 +270,126 @@ class LazyTest(QueryTest): a = q.filter(addresses.c.id==1).one() assert a.user is not None - + u1 = sess.query(User).get(7) - + assert a.user is u1 + def test_backrefs_dont_lazyload(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user') + }) + mapper(Address, addresses) + sess = create_session() + ad = sess.query(Address).filter_by(id=1).one() + assert ad.user.id == 7 + def go(): + ad.user = None + assert ad.user is None + self.assert_sql_count(testing.db, go, 0) + + u1 = sess.query(User).filter_by(id=7).one() + def go(): + assert ad not in u1.addresses + self.assert_sql_count(testing.db, go, 1) + + sess.expire(u1, ['addresses']) + def go(): + assert ad in u1.addresses + self.assert_sql_count(testing.db, go, 1) + + sess.expire(u1, ['addresses']) + ad2 = Address() + def go(): + ad2.user = u1 + assert ad2.user is u1 + self.assert_sql_count(testing.db, go, 0) + + def go(): + assert ad2 in u1.addresses + self.assert_sql_count(testing.db, go, 1) + +class M2OGetTest(FixtureTest): + keep_mappers = False + keep_data = True + + def test_m2o_noload(self): + """test that a NULL foreign key doesn't trigger a lazy load""" + mapper(User, users) + + mapper(Address, addresses, properties={ + 'user':relation(User) + }) + + sess = create_session() + ad1 = Address(email_address='somenewaddress', id=12) + sess.save(ad1) + sess.flush() + sess.clear() + + ad2 = sess.query(Address).get(1) + ad3 = sess.query(Address).get(ad1.id) + def go(): + # one lazy load + assert ad2.user.name == 'jack' + # no lazy load + assert ad3.user is None + self.assert_sql_count(testing.db, go, 1) + +class CorrelatedTest(ORMTest): + keep_mappers = False + keep_data = False + + def define_tables(self, meta): + global user_t, stuff + + user_t = Table('users', meta, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) + + stuff = Table('stuff', meta, + Column('id', Integer, primary_key=True), + Column('date', Date), + Column('user_id', Integer, ForeignKey('users.id'))) + + def insert_data(self): + user_t.insert().execute( + {'id':1, 'name':'user1'}, + {'id':2, 'name':'user2'}, + {'id':3, 'name':'user3'}, + ) + + stuff.insert().execute( + {'id':1, 'user_id':1, 'date':datetime.date(2007, 10, 15)}, + {'id':2, 'user_id':1, 'date':datetime.date(2007, 12, 15)}, + {'id':3, 'user_id':1, 'date':datetime.date(2007, 11, 15)}, + {'id':4, 'user_id':2, 'date':datetime.date(2008, 1, 15)}, + {'id':5, 'user_id':3, 'date':datetime.date(2007, 6, 15)}, + ) + + def test_correlated_lazyload(self): + class User(Base): + pass + + class Stuff(Base): + pass + + mapper(Stuff, stuff) + + stuff_view = select([stuff.c.id]).where(stuff.c.user_id==user_t.c.id).correlate(user_t).order_by(desc(stuff.c.date)).limit(1) + + mapper(User, user_t, properties={ + 'stuff':relation(Stuff, 
primaryjoin=and_(user_t.c.id==stuff.c.user_id, stuff.c.id==(stuff_view.as_scalar()))) + }) + + sess = create_session() + + self.assertEquals(sess.query(User).all(), [ + User(name='user1', stuff=[Stuff(date=datetime.date(2007, 12, 15), id=2)]), + User(name='user2', stuff=[Stuff(id=4, date=datetime.date(2008, 1 , 15))]), + User(name='user3', stuff=[Stuff(id=5, date=datetime.date(2007, 6, 15))]) + ]) + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/orm/lazytest1.py b/test/orm/lazytest1.py index b5296120b3..90cbbe2086 100644 --- a/test/orm/lazytest1.py +++ b/test/orm/lazytest1.py @@ -1,80 +1,83 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * -class LazyTest(AssertMixin): +class LazyTest(TestBase, AssertsExecutionResults): def setUpAll(self): global info_table, data_table, rel_table, metadata - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) info_table = Table('infos', metadata, - Column('pk', Integer, primary_key=True), - Column('info', String)) + Column('pk', Integer, primary_key=True), + Column('info', String(128))) data_table = Table('data', metadata, - Column('data_pk', Integer, primary_key=True), - Column('info_pk', Integer, ForeignKey(info_table.c.pk)), - Column('timeval', Integer), - Column('data_val', String)) + Column('data_pk', Integer, primary_key=True), + Column('info_pk', Integer, + ForeignKey(info_table.c.pk)), + Column('timeval', Integer), + Column('data_val', String(128))) rel_table = Table('rels', metadata, - Column('rel_pk', Integer, primary_key=True), - Column('info_pk', Integer, ForeignKey(info_table.c.pk)), - Column('start', Integer), - Column('finish', Integer)) + Column('rel_pk', Integer, primary_key=True), + Column('info_pk', Integer, + ForeignKey(info_table.c.pk)), + Column('start', Integer), + Column('finish', Integer)) metadata.create_all() info_table.insert().execute( - {'pk':1, 'info':'pk_1_info'}, - {'pk':2, 'info':'pk_2_info'}, - {'pk':3, 'info':'pk_3_info'}, - {'pk':4, 'info':'pk_4_info'}, - {'pk':5, 'info':'pk_5_info'}) + {'pk':1, 'info':'pk_1_info'}, + {'pk':2, 'info':'pk_2_info'}, + {'pk':3, 'info':'pk_3_info'}, + {'pk':4, 'info':'pk_4_info'}, + {'pk':5, 'info':'pk_5_info'}) rel_table.insert().execute( - {'rel_pk':1, 'info_pk':1, 'start':10, 'finish':19}, - {'rel_pk':2, 'info_pk':1, 'start':100, 'finish':199}, - {'rel_pk':3, 'info_pk':2, 'start':20, 'finish':29}, - {'rel_pk':4, 'info_pk':3, 'start':13, 'finish':23}, - {'rel_pk':5, 'info_pk':5, 'start':15, 'finish':25}) + {'rel_pk':1, 'info_pk':1, 'start':10, 'finish':19}, + {'rel_pk':2, 'info_pk':1, 'start':100, 'finish':199}, + {'rel_pk':3, 'info_pk':2, 'start':20, 'finish':29}, + {'rel_pk':4, 'info_pk':3, 'start':13, 'finish':23}, + {'rel_pk':5, 'info_pk':5, 'start':15, 'finish':25}) data_table.insert().execute( - {'data_pk':1, 'info_pk':1, 'timeval':11, 'data_val':'11_data'}, - {'data_pk':2, 'info_pk':1, 'timeval':9, 'data_val':'9_data'}, - {'data_pk':3, 'info_pk':1, 'timeval':13, 'data_val':'13_data'}, - {'data_pk':4, 'info_pk':2, 'timeval':23, 'data_val':'23_data'}, - {'data_pk':5, 'info_pk':2, 'timeval':13, 'data_val':'13_data'}, - {'data_pk':6, 'info_pk':1, 'timeval':15, 'data_val':'15_data'}) + {'data_pk':1, 'info_pk':1, 'timeval':11, 'data_val':'11_data'}, + {'data_pk':2, 'info_pk':1, 'timeval':9, 'data_val':'9_data'}, + {'data_pk':3, 'info_pk':1, 'timeval':13, 'data_val':'13_data'}, + {'data_pk':4, 'info_pk':2, 'timeval':23, 'data_val':'23_data'}, + 
{'data_pk':5, 'info_pk':2, 'timeval':13, 'data_val':'13_data'}, + {'data_pk':6, 'info_pk':1, 'timeval':15, 'data_val':'15_data'}) def tearDownAll(self): metadata.drop_all() - + def testone(self): - """tests a lazy load which has multiple join conditions, including two that are against - the same column in the child table""" + """Tests a lazy load which has multiple join conditions. + + ...including two that are against the same column in the child table. + """ + class Information(object): - pass + pass class Relation(object): - pass + pass class Data(object): - pass + pass session = create_session() - + mapper(Data, data_table) mapper(Relation, rel_table, properties={ - 'datas': relation(Data, - primaryjoin=and_(rel_table.c.info_pk==data_table.c.info_pk, - data_table.c.timeval >= rel_table.c.start, - data_table.c.timeval <= rel_table.c.finish), - foreignkey=data_table.c.info_pk) - } - - ) + primaryjoin=and_( + rel_table.c.info_pk == + data_table.c.info_pk, + data_table.c.timeval >= rel_table.c.start, + data_table.c.timeval <= rel_table.c.finish), + foreign_keys=[data_table.c.info_pk])}) mapper(Information, info_table, properties={ 'rels': relation(Relation) }) @@ -84,7 +87,5 @@ class LazyTest(AssertMixin): assert len(info.rels) == 2 assert len(info.rels[0].datas) == 3 -if __name__ == "__main__": - testbase.main() - - +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py index 8b310f86c5..ca6410533d 100644 --- a/test/orm/manytomany.py +++ b/test/orm/manytomany.py @@ -1,8 +1,8 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * - +from sqlalchemy import exceptions class Place(object): '''represents a place''' @@ -17,7 +17,7 @@ class PlaceThingy(object): '''represents a thingy attached to a Place''' def __init__(self, name=None): self.name = name - + class Transition(object): '''represents a transition''' def __init__(self, name=None): @@ -26,7 +26,7 @@ class Transition(object): self.outputs = [] def __repr__(self): return object.__repr__(self)+ " " + repr(self.inputs) + " " + repr(self.outputs) - + class M2MTest(ORMTest): def define_tables(self, metadata): global place @@ -47,7 +47,7 @@ class M2MTest(ORMTest): Column('place_id', Integer, ForeignKey('place.place_id'), nullable=False), Column('name', String(30), nullable=False) ) - + # association table #1 global place_input place_input = Table('place_input', metadata, @@ -68,6 +68,23 @@ class M2MTest(ORMTest): Column('pl2_id', Integer, ForeignKey('place.place_id')), ) + def testerror(self): + mapper(Place, place, properties={ + 'transitions':relation(Transition, secondary=place_input, backref='places') + }) + mapper(Transition, transition, properties={ + 'places':relation(Place, secondary=place_input, backref='transitions') + }) + try: + compile_mappers() + assert False + except exceptions.ArgumentError, e: + assert str(e) in [ + "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'", + "Error creating backref 'places' on relation 'Place.transitions (Transition)': property of that name exists on mapper 'Mapper|Transition|transition'" + ] + + def testcircular(self): """tests a many-to-many relationship from a table to itself.""" @@ -100,7 +117,7 @@ class M2MTest(ORMTest): sess.flush() sess.clear() - l = sess.query(Place).select(order_by=place.c.place_id) + l = sess.query(Place).order_by(place.c.place_id).all() (p1, p2, 
p3, p4, p5, p6, p7) = l assert p1.places == [p2,p3,p5] assert p5.places == [p6] @@ -124,7 +141,7 @@ class M2MTest(ORMTest): Place.mapper = mapper(Place, place, properties = { 'thingies':relation(mapper(PlaceThingy, place_thingy), lazy=False) }) - + Transition.mapper = mapper(Transition, transition, properties = dict( inputs = relation(Place.mapper, place_output, lazy=False), outputs = relation(Place.mapper, place_input, lazy=False), @@ -140,13 +157,12 @@ class M2MTest(ORMTest): sess.flush() sess.clear() - r = sess.query(Transition).select() - self.assert_result(r, Transition, - {'name':'transition1', - 'inputs' : (Place, [{'name':'place1'}]), - 'outputs' : (Place, [{'name':'place2'}, {'name':'place3'}]) - } - ) + r = sess.query(Transition).all() + self.assert_unordered_result(r, Transition, + {'name': 'transition1', + 'inputs': (Place, [{'name':'place1'}]), + 'outputs': (Place, [{'name':'place2'}, {'name':'place3'}]) + }) def testbidirectional(self): """tests a many-to-many backrefs""" @@ -174,7 +190,7 @@ class M2MTest(ORMTest): sess = create_session() [sess.save(x) for x in [t1,t2,t3,p1,p2,p3]] sess.flush() - + self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])}) self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])}) @@ -189,7 +205,7 @@ class M2MTest2(ORMTest): Column('student_id', String(20), ForeignKey('student.name'),primary_key=True), Column('course_id', String(20), ForeignKey('course.name'), primary_key=True)) - def testcircular(self): + def testcircular(self): class Student(object): def __init__(self, name=''): self.name = name @@ -213,12 +229,12 @@ class M2MTest2(ORMTest): sess.save(s1) sess.flush() sess.clear() - s = sess.query(Student).get_by(name='Student1') - c = sess.query(Course).get_by(name='Course3') + s = sess.query(Student).filter_by(name='Student1').one() + c = sess.query(Course).filter_by(name='Course3').one() self.assert_(len(s.courses) == 3) del s.courses[1] self.assert_(len(s.courses) == 2) - + def test_delete(self): """test that many-to-many table gets cleared out with deletion from the backref side""" class Student(object): @@ -244,36 +260,36 @@ class M2MTest2(ORMTest): sess.delete(s1) sess.flush() assert enrolTbl.count().scalar() == 0 - + class M2MTest3(ORMTest): def define_tables(self, metadata): global c, c2a1, c2a2, b, a - c = Table('c', metadata, + c = Table('c', metadata, Column('c1', Integer, primary_key = True), Column('c2', String(20)), ) - a = Table('a', metadata, + a = Table('a', metadata, Column('a1', Integer, primary_key=True), Column('a2', String(20)), Column('c1', Integer, ForeignKey('c.c1')) ) - c2a1 = Table('ctoaone', metadata, + c2a1 = Table('ctoaone', metadata, Column('c1', Integer, ForeignKey('c.c1')), Column('a1', Integer, ForeignKey('a.a1')) ) - c2a2 = Table('ctoatwo', metadata, + c2a2 = Table('ctoatwo', metadata, Column('c1', Integer, ForeignKey('c.c1')), Column('a1', Integer, ForeignKey('a.a1')) ) - b = Table('b', metadata, + b = Table('b', metadata, Column('b1', Integer, primary_key=True), Column('a1', Integer, ForeignKey('a.a1')), Column('b2', Boolean) ) - + def testbasic(self): class C(object):pass class A(object):pass @@ -281,13 +297,13 @@ class M2MTest3(ORMTest): mapper(B, b) - mapper(A, a, + mapper(A, a, properties = { 'tbs' : relation(B, primaryjoin=and_(b.c.a1==a.c.a1, b.c.b2 == True), lazy=False), } ) - mapper(C, c, + mapper(C, c, properties = { 'a1s' : relation(A, secondary=c2a1, lazy=False), 'a2s' : relation(A, secondary=c2a2, 
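
The new testerror() above pins down that giving both sides of the same many-to-many their own backref'ed relation() raises ArgumentError at compile time. The non-conflicting spelling declares the relation once and lets backref create the reverse side; a minimal sketch (tables and classes here are illustrative, not the fixture schema; no database is needed):

    from sqlalchemy import *
    from sqlalchemy.orm import *

    metadata = MetaData()
    place = Table('place', metadata,
        Column('place_id', Integer, primary_key=True),
        Column('name', String(30)))
    transition = Table('transition', metadata,
        Column('transition_id', Integer, primary_key=True),
        Column('name', String(30)))
    place_input = Table('place_input', metadata,
        Column('place_id', Integer, ForeignKey('place.place_id')),
        Column('transition_id', Integer, ForeignKey('transition.transition_id')))

    class Place(object): pass
    class Transition(object): pass

    mapper(Place, place)
    # one relation() plus backref; a second relation() with its own backref
    # on Place is what the test above expects to fail.
    mapper(Transition, transition, properties={
        'places': relation(Place, secondary=place_input, backref='transitions')
    })
    compile_mappers()

    t, p = Transition(), Place()
    t.places.append(p)
    assert len(p.transitions) == 1 and p.transitions[0] is t   # backref keeps both sides in sync
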
lazy=False) @@ -297,6 +313,5 @@ class M2MTest3(ORMTest): o1 = create_session().query(C).get(1) -if __name__ == "__main__": - testbase.main() - +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index b72a10516a..7dce096145 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1,16 +1,17 @@ """tests general mapper operations with an emphasis on selecting/loading""" -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions, sql from sqlalchemy.orm import * -import sqlalchemy.exceptions as exceptions from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt from testlib import * +from testlib import fixtures from testlib.tables import * import testlib.tables as tables -class MapperSuperTest(AssertMixin): +class MapperSuperTest(TestBase, AssertsExecutionResults): def setUpAll(self): tables.create() tables.data() @@ -20,10 +21,10 @@ class MapperSuperTest(AssertMixin): clear_mappers() def setUp(self): pass - + class MapperTest(MapperSuperTest): - def testpropconflict(self): + def test_propconflict(self): """test that a backref created against an existing mapper with a property name conflict raises a decent error message""" mapper(Address, addresses) @@ -31,106 +32,80 @@ class MapperTest(MapperSuperTest): properties={ 'addresses':relation(Address, backref='email_address') }) - try: - class_mapper(Address) - class_mapper(User) - assert False - except exceptions.ArgumentError: - pass + self.assertRaises(exceptions.ArgumentError, compile_mappers) + + def test_prop_accessor(self): + mapper(User, users) + self.assertRaises(NotImplementedError, getattr, class_mapper(User), 'properties') - def testbadcascade(self): + def test_badcascade(self): mapper(Address, addresses) - try: - mapper(User, users, properties={'addresses':relation(Address, cascade="fake, all, delete-orphan")}) - assert False - except exceptions.ArgumentError, e: - assert str(e) == "Invalid cascade option 'fake'" - - def testcolumnprefix(self): - mapper(User, users, column_prefix='_') + self.assertRaises(exceptions.ArgumentError, relation, Address, cascade="fake, all, delete-orphan") + + def test_columnprefix(self): + mapper(User, users, column_prefix='_', properties={ + 'user_name':synonym('_user_name') + }) + s = create_session() u = s.get(User, 7) assert u._user_name=='jack' - assert u._user_id ==7 - assert not hasattr(u, 'user_name') - - def testrefresh(self): - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), backref='user')}) - s = create_session() - u = s.get(User, 7) - u.user_name = 'foo' - a = Address() - assert object_session(a) is None - u.addresses.append(a) + assert u._user_id ==7 + u2 = s.query(User).filter_by(user_name='jack').one() + assert u is u2 - self.assert_(a in u.addresses) + def test_no_pks(self): + s = select([users.c.user_name]).alias('foo') + self.assertRaises(exceptions.ArgumentError, mapper, User, s) + + def test_recompile_on_othermapper(self): + """test the global '_new_mappers' flag such that a compile + trigger on an already-compiled mapper still triggers a check against all mappers.""" - s.refresh(u) - - # its refreshed, so not dirty - self.assert_(u not in s.dirty) + from sqlalchemy.orm import mapperlib - # username is back to the DB - self.assert_(u.user_name == 'jack') + mapper(User, users) + compile_mappers() + assert mapperlib._new_mappers is False - self.assert_(a not in u.addresses) + m = mapper(Address, addresses, 
properties={'user':relation(User, backref="addresses")}) - u.user_name = 'foo' - u.addresses.append(a) - # now its dirty - self.assert_(u in s.dirty) - self.assert_(u.user_name == 'foo') - self.assert_(a in u.addresses) - s.expire(u) - - # get the attribute, it refreshes - self.assert_(u.user_name == 'jack') - self.assert_(a not in u.addresses) - + assert m._Mapper__props_init is False + assert mapperlib._new_mappers is True + u = User() + assert User.addresses + assert mapperlib._new_mappers is False - def testexpirecascade(self): - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), cascade="all, refresh-expire")}) - s = create_session() - u = s.get(User, 8) - u.addresses[0].email_address = 'someotheraddress' - s.expire(u) - assert u.addresses[0].email_address == 'ed@wood.com' - - def testrefreshwitheager(self): - """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders""" - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)}) - s = create_session() - u = s.get(User, 8) - assert len(u.addresses) == 3 - s.refresh(u) - assert len(u.addresses) == 3 + def test_compileonsession(self): + m = mapper(User, users) + session = create_session() + session.connection(m) + def test_incompletecolumns(self): + """test loading from a select which does not contain all columns""" + mapper(Address, addresses) s = create_session() - u = s.get(User, 8) - assert len(u.addresses) == 3 - s.expire(u) - assert len(u.addresses) == 3 - - def testbadconstructor(self): + a = s.query(Address).from_statement(select([addresses.c.address_id, addresses.c.user_id])).first() + assert a.user_id == 7 + assert a.address_id == 1 + # email address auto-defers + assert 'email_addres' not in a.__dict__ + assert a.email_address == 'jack@bean.com' + + def test_badconstructor(self): """test that if the construction of a mapped class fails, the instnace does not get placed in the session""" class Foo(object): def __init__(self, one, two): pass mapper(Foo, users) sess = create_session() - try: - Foo('one', _sa_session=sess) - assert False - except: - assert len(list(sess)) == 0 - try: - Foo('one') - assert False - except TypeError, e: - pass + self.assertRaises(TypeError, Foo, 'one', _sa_session=sess) + assert len(list(sess)) == 0 + self.assertRaises(TypeError, Foo, 'one') - def testconstructorexceptions(self): - """test that exceptions raised in the mapped class are not masked by sa decorations""" + @testing.uses_deprecated('SessionContext', 'SessionContextExt') + def test_constructorexceptions(self): + """test that exceptions raised in the mapped class are not masked by sa decorations""" ex = AssertionError('oops') sess = create_session() @@ -150,96 +125,281 @@ class MapperTest(MapperSuperTest): def bad_expunge(foo): raise Exception("this exception should be stated as a warning") - import warnings - warnings.filterwarnings("always", r".*this exception should be stated as a warning") - sess.expunge = bad_expunge try: Foo(_sa_session=sess) assert False except Exception, e: - assert e is ex - - def testrefresh_lazy(self): - """test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems""" - s = create_session() - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) - q2 = s.query(User).options(lazyload('addresses')) - u = q2.selectfirst(users.c.user_id==8) 
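
test_incompletecolumns above depends on columns that are absent from the incoming statement being treated as deferred rather than raising, so they load on first access. A rough standalone sketch of that behavior (illustrative table and in-memory SQLite engine; the 0.4-era API as used in these tests):

    from sqlalchemy import *
    from sqlalchemy.orm import *

    engine = create_engine('sqlite://')
    metadata = MetaData(engine)
    addresses = Table('addresses', metadata,
        Column('address_id', Integer, primary_key=True),
        Column('user_id', Integer),
        Column('email_address', String(100)))
    metadata.create_all()
    addresses.insert().execute(address_id=1, user_id=7,
                               email_address='jack@bean.com')

    class Address(object): pass
    mapper(Address, addresses)

    sess = create_session(bind=engine)
    # hand the query a statement that omits email_address entirely
    stmt = select([addresses.c.address_id, addresses.c.user_id])
    a = sess.query(Address).from_statement(stmt).first()

    assert 'email_address' not in a.__dict__    # not loaded yet
    assert a.email_address == 'jack@bean.com'   # deferred load fires here
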
- def go(): - s.refresh(u) - self.assert_sql_count(testbase.db, go, 1) + assert isinstance(e, exceptions.SAWarning) - def testexpire(self): - """test the expire function""" - s = create_session() - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)}) - u = s.get(User, 7) - assert(len(u.addresses) == 1) - u.user_name = 'foo' - del u.addresses[0] - s.expire(u) - # test plain expire - self.assert_(u.user_name =='jack') - self.assert_(len(u.addresses) == 1) - - # we're changing the database here, so if this test fails in the middle, - # it'll screw up the other tests which are hardcoded to 7/'jack' - u.user_name = 'foo' - s.flush() - # change the value in the DB - users.update(users.c.user_id==7, values=dict(user_name='jack')).execute() - s.expire(u) - # object isnt refreshed yet, using dict to bypass trigger - self.assert_(u.__dict__.get('user_name') != 'jack') - # do a select - s.query(User).select() - # test that it refreshed - self.assert_(u.__dict__['user_name'] == 'jack') - - # object should be back to normal now, - # this should *not* produce a SELECT statement (not tested here though....) - self.assert_(u.user_name =='jack') - - def testrefresh2(self): - """test a hang condition that was occuring on expire/refresh""" - - s = create_session() - m1 = mapper(Address, addresses) + clear_mappers() - m2 = mapper(User, users, properties = dict(addresses=relation(Address,private=True,lazy=False)) ) - assert m1._Mapper__is_compiled is False - assert m2._Mapper__is_compiled is False - -# compile_mappers() - print "NEW USER" - u=User() - print "NEW USER DONE" - assert m2._Mapper__is_compiled is True - u.user_name='Justin' - a = Address() - a.address_id=17 # to work around the hardcoded IDs in this test suite.... - u.addresses.append(a) - s.flush() - s.clear() - u = s.query(User).selectfirst() - print u.user_name + # test that TypeError is raised for illegal constructor args, + # whether or not explicit __init__ is present [ticket:908] + class Foo(object): + def __init__(self): + pass + class Bar(object): + pass + + mapper(Foo, users) + mapper(Bar, addresses) + try: + Foo(x=5) + assert False + except TypeError: + assert True - #ok so far - s.expire(u) #hangs when - print u.user_name #this line runs + try: + Bar(x=5) + assert False + except TypeError: + assert True - s.refresh(u) #hangs - - def testprops(self): - """tests the various attributes of the properties attached to classes""" + def test_props(self): m = mapper(User, users, properties = { 'addresses' : relation(mapper(Address, addresses)) }).compile() self.assert_(User.addresses.property is m.get_property('addresses')) - - def testrecursiveselectby(self): + def test_compileonprop(self): + mapper(User, users, properties = { + 'addresses' : relation(mapper(Address, addresses)) + }) + User.addresses.any(Address.email_address=='foo@bar.com') + clear_mappers() + + mapper(User, users, properties = { + 'addresses' : relation(mapper(Address, addresses)) + }) + assert (User.user_id==3).compare(users.c.user_id==3) + + clear_mappers() + + class Foo(User):pass + mapper(User, users) + mapper(Foo, addresses, inherits=User) + assert getattr(Foo().__class__, 'user_name').impl is not None + + def test_compileon_getprops(self): + m =mapper(User, users) + + assert not m.compiled + assert list(m.iterate_properties) + assert m.compiled + clear_mappers() + + m= mapper(User, users) + assert not m.compiled + assert m.get_property('user_name') + assert m.compiled + + def test_add_property(self): + assert_col = [] + + class 
User(object): + def _get_user_name(self): + assert_col.append(('get', self._user_name)) + return self._user_name + def _set_user_name(self, name): + assert_col.append(('set', name)) + self._user_name = name + user_name = property(_get_user_name, _set_user_name) + + def _uc_user_name(self): + if self._user_name is None: + return None + return self._user_name.upper() + uc_user_name = property(_uc_user_name) + uc_user_name2 = property(_uc_user_name) + + m = mapper(User, users) + mapper(Address, addresses) + + class UCComparator(PropComparator): + def __eq__(self, other): + cls = self.prop.parent.class_ + col = getattr(cls, 'user_name') + if other is None: + return col == None + else: + return func.upper(col) == func.upper(other) + + m.add_property('_user_name', deferred(users.c.user_name)) + m.add_property('user_name', synonym('_user_name')) + m.add_property('addresses', relation(Address)) + m.add_property('uc_user_name', comparable_property(UCComparator)) + m.add_property('uc_user_name2', comparable_property( + UCComparator, User.uc_user_name2)) + + sess = create_session(transactional=True) + assert sess.query(User).get(7) + + u = sess.query(User).filter_by(user_name='jack').one() + + def go(): + self.assert_result([u], User, user_address_result[0]) + assert u.user_name == 'jack' + assert u.uc_user_name == 'JACK' + assert u.uc_user_name2 == 'JACK' + assert assert_col == [('get', 'jack')], str(assert_col) + self.assert_sql_count(testing.db, go, 2) + + u.name = 'ed' + u3 = User() + u3.user_name = 'some user' + sess.save(u3) + sess.flush() + sess.rollback() + + def test_replace_property(self): + m = mapper(User, users) + m.add_property('_user_name',users.c.user_name) + m.add_property('user_name', synonym('_user_name', proxy=True)) + + sess = create_session() + u = sess.query(User).filter_by(user_name='jack').one() + assert u._user_name == 'jack' + assert u.user_name == 'jack' + u.user_name = 'jacko' + assert m._columntoproperty[users.c.user_name] is m.get_property('_user_name') + + clear_mappers() + + m = mapper(User, users) + m.add_property('user_name', synonym('_user_name', map_column=True)) + + sess.clear() + u = sess.query(User).filter_by(user_name='jack').one() + assert u._user_name == 'jack' + assert u.user_name == 'jack' + u.user_name = 'jacko' + assert m._columntoproperty[users.c.user_name] is m.get_property('_user_name') + + def test_synonym_replaces_backref(self): + assert_calls = [] + class Address(object): + def _get_user(self): + assert_calls.append("get") + return self._user + def _set_user(self, user): + assert_calls.append("set") + self._user = user + user = property(_get_user, _set_user) + + # synonym is created against nonexistent prop + mapper(Address, addresses, properties={ + 'user':synonym('_user') + }) + compile_mappers() + + # later, backref sets up the prop + mapper(User, users, properties={ + 'addresses':relation(Address, backref='_user') + }) + + sess = create_session() + u1 = sess.query(User).get(7) + u2 = sess.query(User).get(8) + # comparaison ops need to work + a1 = sess.query(Address).filter(Address.user==u1).one() + assert a1.address_id == 1 + a1.user = u2 + assert a1.user is u2 + self.assertEquals(assert_calls, ["set", "get"]) + + def test_self_ref_syn(self): + t = Table('nodes', MetaData(), + Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('nodes.id'))) + + class Node(object): + pass + + mapper(Node, t, properties={ + '_children':relation(Node, backref=backref('_parent', remote_side=t.c.id)), + 
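
add_property() and the synonym()-based replacement exercised above can also be shown in isolation: properties attached after mapper() construction instrument the class the same way as ones passed up front. A small sketch (illustrative table and in-memory SQLite engine, not the fixture tables):

    from sqlalchemy import *
    from sqlalchemy.orm import *

    engine = create_engine('sqlite://')
    metadata = MetaData(engine)
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('user_name', String(30)))
    metadata.create_all()
    users.insert().execute(id=7, user_name='jack')

    class User(object): pass

    m = mapper(User, users)
    # re-map the column under a private name, then put a synonym in its place
    m.add_property('_user_name', users.c.user_name)
    m.add_property('user_name', synonym('_user_name', proxy=True))

    sess = create_session(bind=engine)
    u = sess.query(User).filter_by(user_name='jack').one()
    assert u._user_name == 'jack'
    assert u.user_name == 'jack'
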
'children':synonym('_children'), + 'parent':synonym('_parent') + }) + + n1 = Node() + n2 = Node() + n1.children.append(n2) + assert n2.parent is n2._parent is n1 + assert n1.children[0] is n1._children[0] is n2 + self.assertEquals(str(Node.parent == n2), ":param_1 = nodes.parent_id") + + def test_illegal_non_primary(self): + mapper(User, users) + mapper(Address, addresses) + try: + mapper(User, users, non_primary=True, properties={ + 'addresses':relation(Address) + }).compile() + assert False + except exceptions.ArgumentError, e: + assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e) + + def test_illegal_non_primary_2(self): + try: + mapper(User, users, non_primary=True) + assert False + except exceptions.InvalidRequestError, e: + assert "Configure a primary mapper first" in str(e) + + def test_propfilters(self): + t = Table('person', MetaData(), + Column('id', Integer, primary_key=True), + Column('type', String(128)), + Column('name', String(128)), + Column('employee_number', Integer), + Column('boss_id', Integer, ForeignKey('person.id')), + Column('vendor_id', Integer)) + + class Person(object): pass + class Vendor(Person): pass + class Employee(Person): pass + class Manager(Employee): pass + class Hoho(object): pass + class Lala(object): pass + + p_m = mapper(Person, t, polymorphic_on=t.c.type, + include_properties=('id', 'type', 'name')) + e_m = mapper(Employee, inherits=p_m, polymorphic_identity='employee', + properties={ + 'boss': relation(Manager, backref='peon') + }, + exclude_properties=('vendor_id',)) + + m_m = mapper(Manager, inherits=e_m, polymorphic_identity='manager', + include_properties=()) + + v_m = mapper(Vendor, inherits=p_m, polymorphic_identity='vendor', + exclude_properties=('boss_id', 'employee_number')) + h_m = mapper(Hoho, t, include_properties=('id', 'type', 'name')) + l_m = mapper(Lala, t, exclude_properties=('vendor_id', 'boss_id'), + column_prefix="p_") + + p_m.compile() + #compile_mappers() + + def assert_props(cls, want): + have = set([n for n in dir(cls) if not n.startswith('_')]) + want = set(want) + want.add('c') + self.assert_(have == want, repr(have) + " " + repr(want)) + + assert_props(Person, ['id', 'name', 'type']) + assert_props(Employee, ['boss', 'boss_id', 'employee_number', + 'id', 'name', 'type']) + assert_props(Manager, ['boss', 'boss_id', 'employee_number', 'peon', + 'id', 'name', 'type']) + assert_props(Vendor, ['vendor_id', 'id', 'name', 'type']) + assert_props(Hoho, ['id', 'name', 'type']) + assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type']) + + @testing.uses_deprecated('//select_by', '//join_via', '//list') + def test_recursive_select_by_deprecated(self): """test that no endless loop occurs when traversing for select_by""" m = mapper(User, users, properties={ 'orders':relation(mapper(Order, orders), backref='user'), @@ -248,15 +408,39 @@ class MapperTest(MapperSuperTest): q = create_session().query(m) q.select_by(email_address='foo') - def testmappingtojoin(self): + def test_mappingtojoin(self): """test mapping to a join""" usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) m = mapper(User, usersaddresses, primary_key=[users.c.user_id]) q = create_session().query(m) - l = q.select() + l = q.all() self.assert_result(l, User, *user_result[0:2]) - - def testmappingtoouterjoin(self): + + def test_mappingtojoinnopk(self): + metadata = MetaData() + account_ids_table = Table('account_ids', metadata, + Column('account_id', Integer, 
primary_key=True), + Column('username', String(20))) + account_stuff_table = Table('account_stuff', metadata, + Column('account_id', Integer, ForeignKey('account_ids.account_id')), + Column('credit', Numeric)) + class A(object):pass + m = mapper(A, account_ids_table.join(account_stuff_table)) + m.compile() + assert account_ids_table in m._pks_by_table + assert account_stuff_table not in m._pks_by_table + metadata.create_all(testing.db) + try: + sess = create_session(bind=testing.db) + a = A() + sess.save(a) + sess.flush() + assert testing.db.execute(account_ids_table.count()).scalar() == 1 + assert testing.db.execute(account_stuff_table.count()).scalar() == 0 + finally: + metadata.drop_all(testing.db) + + def test_mappingtoouterjoin(self): """test mapping to an outer join, with a composite primary key that allows nulls""" result = [ {'user_id' : 7, 'address_id' : 1}, @@ -265,15 +449,31 @@ class MapperTest(MapperSuperTest): {'user_id' : 8, 'address_id' : 4}, {'user_id' : 9, 'address_id':None} ] - + j = join(users, addresses, isouter=True) m = mapper(User, j, allow_null_pks=True, primary_key=[users.c.user_id, addresses.c.address_id]) q = create_session().query(m) - l = q.select() + l = q.all() self.assert_result(l, User, *result) - - def testcustomjoin(self): + + def test_customjoin(self): + """Tests that select_from totally replace the FROM parameters.""" + + m = mapper(User, users, properties={ + 'orders':relation(mapper(Order, orders, properties={ + 'items':relation(mapper(Item, orderitems)) + })) + }) + + q = create_session().query(m) + l = (q.select_from(users.join(orders).join(orderitems)). + filter(orderitems.c.item_name=='item 4')) + + self.assert_result(l, User, user_result[0]) + + @testing.uses_deprecated('//select') + def test_customjoin_deprecated(self): """test that the from_obj parameter to query.select() can be used to totally replace the FROM parameters of the generated query.""" @@ -286,52 +486,64 @@ class MapperTest(MapperSuperTest): q = create_session().query(m) l = q.select((orderitems.c.item_name=='item 4'), from_obj=[users.join(orders).join(orderitems)]) self.assert_result(l, User, user_result[0]) - - def testorderby(self): + + def test_orderby(self): """test ordering at the mapper and query level""" + # TODO: make a unit test out of these various combinations -# m = mapper(User, users, order_by=desc(users.c.user_name)) + #m = mapper(User, users, order_by=desc(users.c.user_name)) mapper(User, users, order_by=None) -# mapper(User, users) - -# l = create_session().query(User).select(order_by=[desc(users.c.user_name), asc(users.c.user_id)]) - l = create_session().query(User).select() -# l = create_session().query(User).select(order_by=[]) -# l = create_session().query(User).select(order_by=None) - - - @testing.unsupported('firebird') - def testfunction(self): - """test mapping to a SELECT statement that has functions in it.""" - s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')], - users.c.user_id==addresses.c.user_id, group_by=[c for c in users.c]).alias('myselect') + #mapper(User, users) + + #l = create_session().query(User).select(order_by=[desc(users.c.user_name), asc(users.c.user_id)]) + l = create_session().query(User).all() + #l = create_session().query(User).select(order_by=[]) + #l = create_session().query(User).select(order_by=None) + + + @testing.unsupported('firebird') + def test_function(self): + """Test mapping to a SELECT statement that has functions in it.""" + + s = select([users, + (users.c.user_id 
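
test_mappingtojoin and test_mappingtojoinnopk above map a class straight onto a join; when only one of the joined tables carries a primary key, the mapper restricts its primary key (and its INSERTs) to that table. A condensed sketch with made-up tables and an in-memory SQLite engine:

    from sqlalchemy import *
    from sqlalchemy.orm import *

    engine = create_engine('sqlite://')
    metadata = MetaData(engine)
    account_ids = Table('account_ids', metadata,
        Column('account_id', Integer, primary_key=True),
        Column('username', String(20)))
    account_stuff = Table('account_stuff', metadata,
        Column('account_id', Integer, ForeignKey('account_ids.account_id')),
        Column('credit', Integer))
    metadata.create_all()

    class A(object): pass

    # join() infers the ON clause from the foreign key; only account_ids
    # has a primary key, so a flush inserts there and leaves account_stuff
    # alone unless its columns are populated.
    mapper(A, account_ids.join(account_stuff))

    sess = create_session(bind=engine)
    a = A(); a.username = 'fred'
    sess.save(a)
    sess.flush()
    assert engine.execute(account_ids.count()).scalar() == 1
    assert engine.execute(account_stuff.count()).scalar() == 0
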
* 2).label('concat'), + func.count(addresses.c.address_id).label('count')], + users.c.user_id == addresses.c.user_id, + group_by=[c for c in users.c]).alias('myselect') + mapper(User, s) sess = create_session() - l = sess.query(User).select() + l = sess.query(User).all() for u in l: print "User", u.user_id, u.user_name, u.concat, u.count assert l[0].concat == l[0].user_id * 2 == 14 assert l[1].concat == l[1].user_id * 2 == 16 - @testing.unsupported('firebird') - def testcount(self): + @testing.unsupported('firebird') + def test_count(self): """test the count function on Query. - + (why doesnt this work on firebird?)""" mapper(User, users) q = create_session().query(User) self.assert_(q.count()==3) - self.assert_(q.count(users.c.user_id.in_(8,9))==2) + self.assert_(q.count(users.c.user_id.in_([8,9]))==2) + + @testing.unsupported('firebird') + @testing.uses_deprecated('//count_by', '//join_by', '//join_via') + def test_count_by_deprecated(self): + mapper(User, users) + q = create_session().query(User) self.assert_(q.count_by(user_name='fred')==1) - def testmanytomany_count(self): + def test_manytomany_count(self): mapper(Item, orderitems, properties = dict( keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True), )) q = create_session().query(Item) assert q.join('keywords').distinct().count(Keyword.c.name=="red") == 2 - def testoverride(self): + def test_override(self): # assert that overriding a column raises an error try: m = mapper(User, users, properties = { @@ -340,13 +552,13 @@ class MapperTest(MapperSuperTest): self.assert_(False, "should have raised ArgumentError") except exceptions.ArgumentError, e: self.assert_(True) - + clear_mappers() # assert that allow_column_override cancels the error m = mapper(User, users, properties = { 'user_name' : relation(mapper(Address, addresses)) }, allow_column_override=True) - + clear_mappers() # assert that the column being named else where also cancels the error m = mapper(User, users, properties = { @@ -354,51 +566,190 @@ class MapperTest(MapperSuperTest): 'foo' : users.c.user_name, }) - def testsynonym(self): + def test_synonym(self): sess = create_session() + + assert_col = [] + class extendedproperty(property): + attribute = 123 + def __getitem__(self, key): + return 'value' + + class User(object): + def _get_user_name(self): + assert_col.append(('get', self.user_name)) + return self.user_name + def _set_user_name(self, name): + assert_col.append(('set', name)) + self.user_name = name + uname = extendedproperty(_get_user_name, _set_user_name) + mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy = True), - uname = synonym('user_name', proxy=True), + addresses = relation(mapper(Address, addresses), lazy=True), + uname = synonym('user_name'), adlist = synonym('addresses', proxy=True), adname = synonym('addresses') )) - - u = sess.query(User).get_by(uname='jack') + + assert hasattr(User, 'adlist') + assert hasattr(User, 'adname') # as of 0.4.2, synonyms always create a property + + # test compile + assert not isinstance(User.uname == 'jack', bool) + + u = sess.query(User).filter(User.uname=='jack').one() self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1])) - assert hasattr(u, 'adlist') - assert not hasattr(u, 'adname') - - addr = sess.query(Address).get_by(address_id=user_address_result[0]['addresses'][1][0]['address_id']) - u = sess.query(User).get_by(adname=addr) - u2 = sess.query(User).get_by(adlist=addr) + addr = 
sess.query(Address).filter_by(address_id=user_address_result[0]['addresses'][1][0]['address_id']).one() + u = sess.query(User).filter_by(adname=addr).one() + u2 = sess.query(User).filter_by(adlist=addr).one() + assert u is u2 - + assert u not in sess.dirty u.uname = "some user name" + assert len(assert_col) > 0 + assert assert_col == [('set', 'some user name')], str(assert_col) assert u.uname == "some user name" + assert assert_col == [('set', 'some user name'), ('get', 'some user name')], str(assert_col) assert u.user_name == "some user name" assert u in sess.dirty - def testsynonymoptions(self): + assert User.uname.attribute == 123 + assert User.uname['key'] == 'value' + + def test_column_synonyms(self): + """test new-style synonyms which automatically instrument properties, set up aliased column, etc.""" + + sess = create_session() + + assert_col = [] + class User(object): + def _get_user_name(self): + assert_col.append(('get', self._user_name)) + return self._user_name + def _set_user_name(self, name): + assert_col.append(('set', name)) + self._user_name = name + user_name = property(_get_user_name, _set_user_name) + + mapper(Address, addresses) + try: + mapper(User, users, properties = { + 'addresses':relation(Address, lazy=True), + 'not_user_name':synonym('_user_name', map_column=True) + }) + User.not_user_name + assert False + except exceptions.ArgumentError, e: + assert str(e) == "Can't compile synonym '_user_name': no column on table 'users' named 'not_user_name'" + + clear_mappers() + + mapper(Address, addresses) + mapper(User, users, properties = { + 'addresses':relation(Address, lazy=True), + 'user_name':synonym('_user_name', map_column=True) + }) + + # test compile + assert not isinstance(User.user_name == 'jack', bool) + + assert hasattr(User, 'user_name') + assert hasattr(User, '_user_name') + + u = sess.query(User).filter(User.user_name == 'jack').one() + assert u.user_name == 'jack' + u.user_name = 'foo' + assert u.user_name == 'foo' + assert assert_col == [('get', 'jack'), ('set', 'foo'), ('get', 'foo')] + + def test_comparable(self): + class extendedproperty(property): + attribute = 123 + def __getitem__(self, key): + return 'value' + + class UCComparator(PropComparator): + def __eq__(self, other): + cls = self.prop.parent.class_ + col = getattr(cls, 'user_name') + if other is None: + return col == None + else: + return func.upper(col) == func.upper(other) + + def map_(with_explicit_property): + class User(object): + @extendedproperty + def uc_user_name(self): + if self.user_name is None: + return None + return self.user_name.upper() + if with_explicit_property: + args = (UCComparator, User.uc_user_name) + else: + args = (UCComparator,) + + mapper(User, users, properties=dict( + uc_user_name = comparable_property(*args))) + return User + + for User in (map_(True), map_(False)): + sess = create_session() + sess.begin() + q = sess.query(User) + + assert hasattr(User, 'user_name') + assert hasattr(User, 'uc_user_name') + + # test compile + assert not isinstance(User.uc_user_name == 'jack', bool) + u = q.filter(User.uc_user_name=='JACK').one() + + assert u.uc_user_name == "JACK" + assert u not in sess.dirty + + u.user_name = "some user name" + assert u.user_name == "some user name" + assert u in sess.dirty + assert u.uc_user_name == "SOME USER NAME" + + sess.flush() + sess.clear() + + q = sess.query(User) + u2 = q.filter(User.user_name=='some user name').one() + u3 = q.filter(User.uc_user_name=='SOME USER NAME').one() + + assert u2 is u3 + + assert 
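
comparable_property(), exercised in test_comparable above, wraps a plain Python property in a PropComparator subclass so the property can appear in query criteria. A pared-down sketch (illustrative table and in-memory SQLite engine; the comparator mirrors the UCComparator the test defines):

    from sqlalchemy import *
    from sqlalchemy.orm import *

    class UCComparator(PropComparator):
        # turn "User.uc_name == 'X'" into upper(users.name) == upper('X')
        def __eq__(self, other):
            col = getattr(self.prop.parent.class_, 'name')
            if other is None:
                return col == None
            return func.upper(col) == func.upper(other)

    engine = create_engine('sqlite://')
    metadata = MetaData(engine)
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(30)))
    metadata.create_all()
    users.insert().execute(id=1, name='jack')

    class User(object):
        def _uc_name(self):
            if self.name is None:
                return None
            return self.name.upper()
        uc_name = property(_uc_name)

    mapper(User, users, properties={
        'uc_name': comparable_property(UCComparator, User.uc_name)
    })

    sess = create_session(bind=engine)
    u = sess.query(User).filter(User.uc_name == 'JACK').one()
    assert u.uc_name == 'JACK'
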
User.uc_user_name.attribute == 123 + assert User.uc_user_name['key'] == 'value' + sess.rollback() + +class OptionsTest(MapperSuperTest): + @testing.fails_on('maxdb') + def test_synonymoptions(self): sess = create_session() mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True), adlist = synonym('addresses', proxy=True) )) - + def go(): - u = sess.query(User).options(eagerload('adlist')).get_by(user_name='jack') + u = sess.query(User).options(eagerload('adlist')).filter_by(user_name='jack').one() self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1])) - self.assert_sql_count(testbase.db, go, 1) - - def testextensionoptions(self): + self.assert_sql_count(testing.db, go, 1) + + @testing.uses_deprecated('//select_by') + def test_extension_options(self): sess = create_session() class ext1(MapperExtension): def populate_instance(self, mapper, selectcontext, row, instance, **flags): """test options at the Mapper._instance level""" instance.TEST = "hello world" - return EXT_PASS + return EXT_CONTINUE mapper(User, users, extension=ext1(), properties={ 'addresses':relation(mapper(Address, addresses), lazy=False) }) @@ -409,7 +760,7 @@ class MapperTest(MapperSuperTest): def populate_instance(self, mapper, selectcontext, row, instance, **flags): """test options at the Mapper._instance level""" instance.TEST_2 = "also hello world" - return EXT_PASS + return EXT_CONTINUE l = sess.query(User).options(extension(testext())).select_by(x=5) assert l == "HI" l = sess.query(User).options(extension(testext())).get(7) @@ -418,83 +769,85 @@ class MapperTest(MapperSuperTest): assert l.TEST_2 == "also hello world" assert not hasattr(l.addresses[0], 'TEST') assert not hasattr(l.addresses[0], 'TEST2') - - def testeageroptions(self): + + def test_eageroptions(self): """tests that a lazy relation can be upgraded to an eager relation via the options method""" sess = create_session() mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy = True) + addresses = relation(mapper(Address, addresses)) )) - l = sess.query(User).options(eagerload('addresses')).select() + l = sess.query(User).options(eagerload('addresses')).all() def go(): self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 0) + self.assert_sql_count(testing.db, go, 0) - def testeageroptionswithlimit(self): + @testing.fails_on('maxdb') + def test_eageroptionswithlimit(self): sess = create_session() mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True) )) - u = sess.query(User).options(eagerload('addresses')).get_by(user_id=8) + u = sess.query(User).options(eagerload('addresses')).filter_by(user_id=8).one() def go(): assert u.user_id == 8 assert len(u.addresses) == 3 - self.assert_sql_count(testbase.db, go, 0) + self.assert_sql_count(testing.db, go, 0) sess.clear() - + # test that eager loading doesnt modify parent mapper def go(): - u = sess.query(User).get_by(user_id=8) + u = sess.query(User).filter_by(user_id=8).one() assert u.user_id == 8 assert len(u.addresses) == 3 - assert "tbl_row_count" not in self.capture_sql(testbase.db, go) - - def testlazyoptionswithlimit(self): + assert "tbl_row_count" not in self.capture_sql(testing.db, go) + + @testing.fails_on('maxdb') + def test_lazyoptionswithlimit(self): sess = create_session() mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy = False) + addresses = relation(mapper(Address, 
addresses), lazy=False) )) - u = sess.query(User).options(lazyload('addresses')).get_by(user_id=8) + u = sess.query(User).options(lazyload('addresses')).filter_by(user_id=8).one() def go(): assert u.user_id == 8 assert len(u.addresses) == 3 - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) - def testeagerdegrade(self): + def test_eagerdegrade(self): """tests that an eager relation automatically degrades to a lazy relation if eager columns are not available""" sess = create_session() usermapper = mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy = False) - )).compile() + addresses = relation(mapper(Address, addresses), lazy=False) + )) # first test straight eager load, 1 statement def go(): - l = sess.query(usermapper).select() + l = sess.query(usermapper).all() self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) sess.clear() - + # then select just from users. run it into instances. # then assert the data, which will launch 3 more lazy loads # (previous users in session fell out of scope and were removed from session's identity map) def go(): r = users.select().execute() - l = usermapper.instances(r, sess) + l = sess.query(usermapper).instances(r) self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 4) - + self.assert_sql_count(testing.db, go, 4) + clear_mappers() sess.clear() - + # test with a deeper set of eager loads. when we first load the three # users, they will have no addresses or orders. the number of lazy loads when - # traversing the whole thing will be three for the addresses and three for the + # traversing the whole thing will be three for the addresses and three for the # orders. # (previous users in session fell out of scope and were removed from session's identity map) usermapper = mapper(User, users, @@ -511,47 +864,33 @@ class MapperTest(MapperSuperTest): # first test straight eager load, 1 statement def go(): - l = sess.query(usermapper).select() + l = sess.query(usermapper).all() self.assert_result(l, User, *user_all_result) - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) sess.clear() - + # then select just from users. run it into instances. 
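
The option tests in this block all assert the same contract: the lazy/eager setting given to relation() is only a default, and eagerload()/lazyload() swap the strategy per query without altering the mapper. A compact sketch (illustrative tables and an in-memory SQLite engine, not the fixture data):

    from sqlalchemy import *
    from sqlalchemy.orm import *

    engine = create_engine('sqlite://')
    metadata = MetaData(engine)
    users = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(30)))
    addresses = Table('addresses', metadata,
        Column('id', Integer, primary_key=True),
        Column('user_id', Integer, ForeignKey('users.id')),
        Column('email', String(50)))
    metadata.create_all()
    users.insert().execute(id=7, name='jack')
    addresses.insert().execute(id=1, user_id=7, email='jack@bean.com')

    class User(object): pass
    class Address(object): pass

    mapper(Address, addresses)
    mapper(User, users, properties={
        'addresses': relation(Address, lazy=True)    # lazy by default
    })

    sess = create_session(bind=engine)

    # one query with a LEFT OUTER JOIN; u.addresses is populated up front
    u = sess.query(User).options(eagerload('addresses')).filter_by(id=7).one()
    assert len(u.addresses) == 1

    sess.clear()

    # the mapper itself is unchanged: without the option the collection
    # loads lazily again, in a second SELECT on first access
    u = sess.query(User).filter_by(id=7).one()
    assert len(u.addresses) == 1
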
# then assert the data, which will launch 6 more lazy loads def go(): r = users.select().execute() - l = usermapper.instances(r, sess) + l = sess.query(usermapper).instances(r) self.assert_result(l, User, *user_all_result) - self.assert_sql_count(testbase.db, go, 7) - - - def testlazyoptions(self): + self.assert_sql_count(testing.db, go, 7) + + + def test_lazyoptions(self): """tests that an eager relation can be upgraded to a lazy relation via the options method""" sess = create_session() mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy = False) + addresses = relation(mapper(Address, addresses), lazy=False) )) - l = sess.query(User).options(lazyload('addresses')).select() + l = sess.query(User).options(lazyload('addresses')).all() def go(): self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 3) - - def testlatecompile(self): - """tests mappers compiling late in the game""" - - mapper(User, users, properties = {'orders': relation(Order)}) - mapper(Item, orderitems, properties={'keywords':relation(Keyword, secondary=itemkeywords)}) - mapper(Keyword, keywords) - mapper(Order, orders, properties={'items':relation(Item)}) - - sess = create_session() - u = sess.query(User).select() - def go(): - print u[0].orders[1].items[0].keywords[1] - self.assert_sql_count(testbase.db, go, 3) + self.assert_sql_count(testing.db, go, 3) - def testdeepoptions(self): + def test_deepoptions(self): mapper(User, users, properties = { 'orders': relation(mapper(Order, orders, properties = { @@ -560,86 +899,85 @@ class MapperTest(MapperSuperTest): })) })) }) - + sess = create_session() - + # eagerload nothing. - u = sess.query(User).select() + u = sess.query(User).all() def go(): print u[0].orders[1].items[0].keywords[1] - self.assert_sql_count(testbase.db, go, 3) + self.assert_sql_count(testing.db, go, 3) sess.clear() - - - print "-------MARK----------" + # eagerload orders.items.keywords; eagerload_all() implies eager load of orders, orders.items q2 = sess.query(User).options(eagerload_all('orders.items.keywords')) - u = q2.select() + u = q2.all() def go(): print u[0].orders[1].items[0].keywords[1] - print "-------MARK2----------" - self.assert_sql_count(testbase.db, go, 0) + self.assert_sql_count(testing.db, go, 0) sess.clear() # same thing, with separate options calls q2 = sess.query(User).options(eagerload('orders')).options(eagerload('orders.items')).options(eagerload('orders.items.keywords')) - u = q2.select() + u = q2.all() def go(): print u[0].orders[1].items[0].keywords[1] - print "-------MARK3----------" - self.assert_sql_count(testbase.db, go, 0) - print "-------MARK4----------" + self.assert_sql_count(testing.db, go, 0) sess.clear() - + + self.assertRaisesMessage(exceptions.ArgumentError, + r"Can't find entity Mapper\|Order\|orders in Query. Current list: \['Mapper\|User\|users'\]", + sess.query(User).options, eagerload('items', Order) + ) + # eagerload "keywords" on items. 
it will lazy load "orders", then lazy load # the "items" on the order, but on "items" it will eager load the "keywords" - print "-------MARK5----------" q3 = sess.query(User).options(eagerload('orders.items.keywords')) - u = q3.select() - self.assert_sql_count(testbase.db, go, 2) - - + u = q3.all() + self.assert_sql_count(testing.db, go, 2) + + class DeferredTest(MapperSuperTest): - def testbasic(self): + def test_basic(self): """tests a basic "deferred" load""" - + m = mapper(Order, orders, properties={ 'description':deferred(orders.c.description) }) - + o = Order() self.assert_(o.description is None) q = create_session().query(m) def go(): - l = q.select() + l = q.all() o2 = l[2] print o2.description - orderby = str(orders.default_order_by()[0].compile(bind=testbase.db)) - self.assert_sql(testbase.db, go, [ + orderby = str(orders.default_order_by()[0].compile(bind=testing.db)) + self.assert_sql(testing.db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), - ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) + ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :param_1", {'param_1':3}) ]) - def testunsaved(self): + def test_unsaved(self): """test that deferred loading doesnt kick in when just PK cols are set""" m = mapper(Order, orders, properties={ 'description':deferred(orders.c.description) }) - + sess = create_session() o = Order() sess.save(o) o.order_id = 7 def go(): o.description = "some description" - self.assert_sql_count(testbase.db, go, 0) + self.assert_sql_count(testing.db, go, 0) - def testunsavedgroup(self): + def test_unsavedgroup(self): """test that deferred loading doesnt kick in when just PK cols are set""" m = mapper(Order, orders, properties={ 'description':deferred(orders.c.description, group='primary'), @@ -652,21 +990,21 @@ class DeferredTest(MapperSuperTest): o.order_id = 7 def go(): o.description = "some description" - self.assert_sql_count(testbase.db, go, 0) - - def testsave(self): + self.assert_sql_count(testing.db, go, 0) + + def test_save(self): m = mapper(Order, orders, properties={ 'description':deferred(orders.c.description) }) - + sess = create_session() q = sess.query(m) - l = q.select() + l = q.all() o2 = l[2] o2.isopen = 1 sess.flush() - - def testgroup(self): + + def test_group(self): """tests deferred load with a group""" m = mapper(Order, orders, properties = { 'userident':deferred(orders.c.user_id, group='primary'), @@ -676,31 +1014,51 @@ class DeferredTest(MapperSuperTest): sess = create_session() q = sess.query(m) def go(): - l = q.select() + l = q.all() o2 = l[2] print o2.opened, o2.description, o2.userident assert o2.opened == 1 assert o2.userident == 7 assert o2.description == 'order 3' - orderby = str(orders.default_order_by()[0].compile(testbase.db)) - self.assert_sql(testbase.db, go, [ + orderby = str(orders.default_order_by()[0].compile(testing.db)) + self.assert_sql(testing.db, go, [ ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}), - ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) + ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :param_1", 
{'param_1':3}) ]) - - o2 = q.select()[2] + + o2 = q.all()[2] # assert o2.opened == 1 assert o2.description == 'order 3' assert o2 not in sess.dirty o2.description = 'order 3' def go(): sess.flush() - self.assert_sql_count(testbase.db, go, 0) - - def testcommitsstate(self): + self.assert_sql_count(testing.db, go, 0) + + def test_preserve_changes(self): + """test that the deferred load operation doesn't revert modifications on attributes""" + + mapper(Order, orders, properties = { + 'userident':deferred(orders.c.user_id, group='primary'), + 'description':deferred(orders.c.description, group='primary'), + 'opened':deferred(orders.c.isopen, group='primary') + }) + sess = create_session() + o = sess.query(Order).get(3) + assert 'userident' not in o.__dict__ + o.description = 'somenewdescription' + assert o.description == 'somenewdescription' + def go(): + assert o.opened == 1 + self.assert_sql_count(testing.db, go, 1) + assert o.description == 'somenewdescription' + assert o in sess.dirty + + + def test_commitsstate(self): """test that when deferred elements are loaded via a group, they get the proper CommittedState and dont result in changes being committed""" - + m = mapper(Order, orders, properties = { 'userident':deferred(orders.c.user_id, group='primary'), 'description':deferred(orders.c.description, group='primary'), @@ -708,7 +1066,7 @@ class DeferredTest(MapperSuperTest): }) sess = create_session() q = sess.query(m) - o2 = q.select()[2] + o2 = q.all()[2] # this will load the group of attributes assert o2.description == 'order 3' assert o2 not in sess.dirty @@ -717,33 +1075,33 @@ class DeferredTest(MapperSuperTest): def go(): # therefore the flush() shouldnt actually issue any SQL sess.flush() - self.assert_sql_count(testbase.db, go, 0) - - def testoptions(self): + self.assert_sql_count(testing.db, go, 0) + + def test_options(self): """tests using options on a mapper to create deferred and undeferred columns""" m = mapper(Order, orders) sess = create_session() q = sess.query(m) q2 = q.options(defer('user_id')) def go(): - l = q2.select() + l = q2.all() print l[2].user_id - - orderby = str(orders.default_order_by()[0].compile(testbase.db)) - self.assert_sql(testbase.db, go, [ + + orderby = str(orders.default_order_by()[0].compile(testing.db)) + self.assert_sql(testing.db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), - ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) + ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :param_1", {'param_1':3}) ]) sess.clear() q3 = q2.options(undefer('user_id')) def go(): - l = q3.select() + l = q3.all() print l[3].user_id - self.assert_sql(testbase.db, go, [ + self.assert_sql(testing.db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ]) - def testundefergroup(self): + def test_undefergroup(self): """tests undefer_group()""" m = mapper(Order, orders, properties = { 'userident':deferred(orders.c.user_id, group='primary'), @@ -753,19 +1111,36 @@ class DeferredTest(MapperSuperTest): sess = create_session() q = sess.query(m) def go(): - l = q.options(undefer_group('primary')).select() + l = q.options(undefer_group('primary')).all() o2 = l[2] print o2.opened, o2.description, o2.userident assert 
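
The deferred-group tests above rely on deferred columns that share a group loading together: touching any one member emits a single SELECT for the whole group, and undefer_group() folds them back into the main query. A small sketch (illustrative table and in-memory SQLite engine):

    from sqlalchemy import *
    from sqlalchemy.orm import *

    engine = create_engine('sqlite://')
    metadata = MetaData(engine)
    orders = Table('orders', metadata,
        Column('id', Integer, primary_key=True),
        Column('user_id', Integer),
        Column('description', String(50)),
        Column('isopen', Integer))
    metadata.create_all()
    orders.insert().execute(id=3, user_id=7, description='order 3', isopen=1)

    class Order(object): pass

    mapper(Order, orders, properties={
        'userident':   deferred(orders.c.user_id, group='primary'),
        'description': deferred(orders.c.description, group='primary'),
        'opened':      deferred(orders.c.isopen, group='primary'),
    })

    sess = create_session(bind=engine)
    o = sess.query(Order).get(3)             # SELECT touches only orders.id
    assert 'description' not in o.__dict__
    assert o.opened == 1                     # one SELECT loads the whole group
    assert o.description == 'order 3'        # already present, no further SQL

    sess.clear()
    # or pull the group in with the main statement:
    o = sess.query(Order).options(undefer_group('primary')).filter_by(id=3).one()
    assert o.description == 'order 3'
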
o2.opened == 1 assert o2.userident == 7 assert o2.description == 'order 3' - orderby = str(orders.default_order_by()[0].compile(testbase.db)) - self.assert_sql(testbase.db, go, [ + orderby = str(orders.default_order_by()[0].compile(testing.db)) + self.assert_sql(testing.db, go, [ ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen, orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}), ]) - - def testdeepoptions(self): + def test_locates_col(self): + """test that manually adding a col to the result undefers the column""" + mapper(Order, orders, properties={ + 'description':deferred(orders.c.description) + }) + + sess = create_session() + o1 = sess.query(Order).first() + def go(): + assert o1.description == 'order 1' + self.assert_sql_count(testing.db, go, 1) + + sess = create_session() + o1 = sess.query(Order).add_column(orders.c.description).first()[0] + def go(): + assert o1.description == 'order 1' + self.assert_sql_count(testing.db, go, 0) + + def test_deepoptions(self): m = mapper(User, users, properties={ 'orders':relation(mapper(Order, orders, properties={ 'items':relation(mapper(Item, orderitems, properties={ @@ -775,19 +1150,19 @@ class DeferredTest(MapperSuperTest): }) sess = create_session() q = sess.query(m) - l = q.select() + l = q.all() item = l[0].orders[1].items[1] def go(): print item.item_name - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) self.assert_(item.item_name == 'item 4') sess.clear() q2 = q.options(undefer('orders.items.item_name')) - l = q2.select() + l = q2.all() item = l[0].orders[1].items[1] def go(): print item.item_name - self.assert_sql_count(testbase.db, go, 0) + self.assert_sql_count(testing.db, go, 0) self.assert_(item.item_name == 'item 4') class CompositeTypesTest(ORMTest): @@ -797,8 +1172,8 @@ class CompositeTypesTest(ORMTest): Column('id', Integer, primary_key=True), Column('version_id', Integer, primary_key=True), Column('name', String(30))) - - edges = Table('edges', metadata, + + edges = Table('edges', metadata, Column('id', Integer, primary_key=True), Column('graph_id', Integer, nullable=False), Column('graph_version_id', Integer, nullable=False), @@ -814,8 +1189,8 @@ class CompositeTypesTest(ORMTest): def __init__(self, x, y): self.x = x self.y = y - def __colset__(self): - return [self.x, self.y] + def __composite_values__(self): + return [self.x, self.y] def __eq__(self, other): return other.x == self.x and other.y == self.y def __ne__(self, other): @@ -827,7 +1202,7 @@ class CompositeTypesTest(ORMTest): def __init__(self, start, end): self.start = start self.end = end - + mapper(Graph, graphs, properties={ 'edges':relation(Edge) }) @@ -835,7 +1210,7 @@ class CompositeTypesTest(ORMTest): 'start':composite(Point, edges.c.x1, edges.c.y1), 'end':composite(Point, edges.c.x2, edges.c.y2) }) - + sess = create_session() g = Graph() g.id = 1 @@ -844,13 +1219,13 @@ class CompositeTypesTest(ORMTest): g.edges.append(Edge(Point(14, 5), Point(2, 7))) sess.save(g) sess.flush() - + sess.clear() g2 = sess.query(Graph).get([g.id, g.version_id]) for e1, e2 in zip(g.edges, g2.edges): assert e1.start == e2.start assert e1.end == e2.end - + g2.edges[1].end = Point(18, 4) sess.flush() sess.clear() @@ -864,63 +1239,64 @@ class CompositeTypesTest(ORMTest): assert sess.query(Edge).get(g2.edges[1].id).end == Point(19, 5) g.edges[1].end = Point(19, 5) - + sess.clear() def go(): g2 = sess.query(Graph).options(eagerload('edges')).get([g.id, 
g.version_id]) for e1, e2 in zip(g.edges, g2.edges): assert e1.start == e2.start assert e1.end == e2.end - self.assert_sql_count(testbase.db, go, 1) - + self.assert_sql_count(testing.db, go, 1) + # test comparison of CompositeProperties to their object instances g = sess.query(Graph).get([1, 1]) assert sess.query(Edge).filter(Edge.start==Point(3, 4)).one() is g.edges[0] - + assert sess.query(Edge).filter(Edge.start!=Point(3, 4)).first() is g.edges[1] assert sess.query(Edge).filter(Edge.start==None).all() == [] - - + + def test_pk(self): """test using a composite type as a primary key""" - + class Version(object): def __init__(self, id, version): self.id = id self.version = version - def __colset__(self): - return [self.id, self.version] + def __composite_values__(self): + # a tuple this time + return (self.id, self.version) def __eq__(self, other): return other.id == self.id and other.version == self.version def __ne__(self, other): return not self.__eq__(other) - + class Graph(object): def __init__(self, version): self.version = version - + mapper(Graph, graphs, properties={ 'version':composite(Version, graphs.c.id, graphs.c.version_id) }) - + sess = create_session() g = Graph(Version(1, 1)) sess.save(g) sess.flush() - + sess.clear() g2 = sess.query(Graph).get([1, 1]) assert g.version == g2.version sess.clear() - + g2 = sess.query(Graph).get(Version(1, 1)) assert g.version == g2.version - - - + + + class NoLoadTest(MapperSuperTest): - def testbasic(self): + def test_basic(self): """tests a basic one-to-many lazy load""" m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy=None) @@ -928,44 +1304,379 @@ class NoLoadTest(MapperSuperTest): q = create_session().query(m) l = [None] def go(): - x = q.select(users.c.user_id == 7) + x = q.filter(users.c.user_id == 7).all() x[0].addresses l[0] = x - self.assert_sql_count(testbase.db, go, 1) - + self.assert_sql_count(testing.db, go, 1) + self.assert_result(l[0], User, {'user_id' : 7, 'addresses' : (Address, [])}, ) - def testoptions(self): + + def test_options(self): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy=None) )) q = create_session().query(m).options(lazyload('addresses')) l = [None] def go(): - x = q.select(users.c.user_id == 7) + x = q.filter(users.c.user_id == 7).all() x[0].addresses l[0] = x - self.assert_sql_count(testbase.db, go, 2) - + self.assert_sql_count(testing.db, go, 2) + self.assert_result(l[0], User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, ) -class MapperExtensionTest(MapperSuperTest): - def testcreateinstance(self): +class MapperExtensionTest(TestBase): + def setUpAll(self): + tables.create() + + global methods, Ext + + methods = [] + class Ext(MapperExtension): - def create_instance(self, *args, **kwargs): - return User() - m = mapper(Address, addresses) - m = mapper(User, users, extension=Ext(), properties = dict( - addresses = relation(Address, lazy=True), - )) + def load(self, query, *args, **kwargs): + methods.append('load') + return EXT_CONTINUE + + def get(self, query, *args, **kwargs): + methods.append('get') + return EXT_CONTINUE + + def translate_row(self, mapper, context, row): + methods.append('translate_row') + return EXT_CONTINUE + + def create_instance(self, mapper, selectcontext, row, class_): + methods.append('create_instance') + return EXT_CONTINUE + + def append_result(self, mapper, selectcontext, row, instance, result, **flags): + methods.append('append_result') + return EXT_CONTINUE + + 
def populate_instance(self, mapper, selectcontext, row, instance, **flags): + methods.append('populate_instance') + return EXT_CONTINUE + + def before_insert(self, mapper, connection, instance): + methods.append('before_insert') + return EXT_CONTINUE + + def after_insert(self, mapper, connection, instance): + methods.append('after_insert') + return EXT_CONTINUE + + def before_update(self, mapper, connection, instance): + methods.append('before_update') + return EXT_CONTINUE + + def after_update(self, mapper, connection, instance): + methods.append('after_update') + return EXT_CONTINUE + + def before_delete(self, mapper, connection, instance): + methods.append('before_delete') + return EXT_CONTINUE + + def after_delete(self, mapper, connection, instance): + methods.append('after_delete') + return EXT_CONTINUE + + def tearDown(self): + clear_mappers() + methods[:] = [] + tables.delete() + + def tearDownAll(self): + tables.drop() + + def test_basic(self): + """test that common user-defined methods get called.""" + mapper(User, users, extension=Ext()) + sess = create_session() + u = User() + sess.save(u) + sess.flush() + u = sess.query(User).load(u.user_id) + sess.clear() + u = sess.query(User).get(u.user_id) + u.user_name = 'foobar' + sess.flush() + sess.delete(u) + sess.flush() + self.assertEquals(methods, + ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 'translate_row', + 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete'] + ) + + def test_inheritance(self): + # test using inheritance + class AdminUser(User): + pass + + mapper(User, users, extension=Ext()) + mapper(AdminUser, addresses, inherits=User) + + sess = create_session() + am = AdminUser() + sess.save(am) + sess.flush() + am = sess.query(AdminUser).load(am.user_id) + sess.clear() + am = sess.query(AdminUser).get(am.user_id) + am.user_name = 'foobar' + sess.flush() + sess.delete(am) + sess.flush() + self.assertEquals(methods, + ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', + 'translate_row', 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete']) + + def test_after_with_no_changes(self): + # test that after_update is called even if no cols were updated + + mapper(Item, orderitems, extension=Ext() , properties={ + 'keywords':relation(Keyword, secondary=itemkeywords) + }) + mapper(Keyword, keywords, extension=Ext() ) + + sess = create_session() + i1 = Item() + k1 = Keyword() + sess.save(i1) + sess.save(k1) + sess.flush() + self.assertEquals(methods, ['before_insert', 'after_insert', 'before_insert', 'after_insert']) + + methods[:] = [] + i1.keywords.append(k1) + sess.flush() + self.assertEquals(methods, ['before_update', 'after_update']) + + + def test_inheritance_with_dupes(self): + # test using inheritance, same extension on both mappers + class AdminUser(User): + pass + + ext = Ext() + mapper(User, users, extension=ext) + mapper(AdminUser, addresses, inherits=User, extension=ext) + + sess = create_session() + am = AdminUser() + sess.save(am) + sess.flush() + am = sess.query(AdminUser).load(am.user_id) + sess.clear() + am = sess.query(AdminUser).get(am.user_id) + am.user_name = 'foobar' + sess.flush() + sess.delete(am) + sess.flush() + self.assertEquals(methods, + ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 'translate_row', + 
'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete'] + ) + +class RequirementsTest(ORMTest): + """Tests the contract for user classes.""" + + def define_tables(self, metadata): + global t1, t2, t3, t4, t5, t6 + + t1 = Table('ht1', metadata, + Column('id', Integer, primary_key=True), + Column('value', String(10))) + t2 = Table('ht2', metadata, + Column('id', Integer, primary_key=True), + Column('ht1_id', Integer, ForeignKey('ht1.id')), + Column('value', String(10))) + t3 = Table('ht3', metadata, + Column('id', Integer, primary_key=True), + Column('value', String(10))) + t4 = Table('ht4', metadata, + Column('ht1_id', Integer, ForeignKey('ht1.id'), + primary_key=True), + Column('ht3_id', Integer, ForeignKey('ht3.id'), + primary_key=True)) + t5 = Table('ht5', metadata, + Column('ht1_id', Integer, ForeignKey('ht1.id'), + primary_key=True), + ) + t6 = Table('ht6', metadata, + Column('ht1a_id', Integer, ForeignKey('ht1.id'), + primary_key=True), + Column('ht1b_id', Integer, ForeignKey('ht1.id'), + primary_key=True), + Column('value', String(10))) + + def test_baseclass(self): + class OldStyle: + pass + + self.assertRaises(exceptions.ArgumentError, mapper, OldStyle, t1) + + class NoWeakrefSupport(str): + pass + + # TODO: is weakref support detectable without an instance? + #self.assertRaises(exceptions.ArgumentError, mapper, NoWeakrefSupport, t2) + + def test_comparison_overrides(self): + """Simple tests to ensure users can supply comparison __methods__. + + The suite-level test --options are better suited to detect + problems- they add selected __methods__ across the board on all + ORM tests. This test simply shoves a variety of operations + through the ORM to catch basic regressions early in a standard + test run. + """ + + # adding these methods directly to each class to avoid decoration + # by the testlib decorators. 
+ class H1(object): + def __init__(self, value='abc'): + self.value = value + def __nonzero__(self): + return False + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + return False + class H2(object): + def __init__(self, value='abc'): + self.value = value + def __nonzero__(self): + return False + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + return False + class H3(object): + def __init__(self, value='abc'): + self.value = value + def __nonzero__(self): + return False + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + return False + class H6(object): + def __init__(self, value='abc'): + self.value = value + def __nonzero__(self): + return False + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + return False + + + mapper(H1, t1, properties={ + 'h2s': relation(H2, backref='h1'), + 'h3s': relation(H3, secondary=t4, backref='h1s'), + 'h1s': relation(H1, secondary=t5, backref='parent_h1'), + 't6a': relation(H6, backref='h1a', + primaryjoin=t1.c.id==t6.c.ht1a_id), + 't6b': relation(H6, backref='h1b', + primaryjoin=t1.c.id==t6.c.ht1b_id), + }) + mapper(H2, t2) + mapper(H3, t3) + mapper(H6, t6) + + s = create_session() + for i in range(3): + h1 = H1() + s.save(h1) + + h1.h2s.append(H2()) + h1.h3s.extend([H3(), H3()]) + h1.h1s.append(H1()) + + s.flush() + self.assertEquals(t1.count().scalar(), 4) + + h6 = H6() + h6.h1a = h1 + h6.h1b = h1 + + h6 = H6() + h6.h1a = h1 + h6.h1b = x = H1() + assert x in s + + h6.h1b.h2s.append(H2()) + + s.flush() + + h1.h2s.extend([H2(), H2()]) + s.flush() + + h1s = s.query(H1).options(eagerload('h2s')).all() + self.assertEqual(len(h1s), 5) + + self.assert_unordered_result(h1s, H1, + {'h2s': []}, + {'h2s': []}, + {'h2s': (H2, [{'value': 'abc'}, + {'value': 'abc'}, + {'value': 'abc'}])}, + {'h2s': []}, + {'h2s': (H2, [{'value': 'abc'}])}) + + h1s = s.query(H1).options(eagerload('h3s')).all() + + self.assertEqual(len(h1s), 5) + h1s = s.query(H1).options(eagerload_all('t6a.h1b'), + eagerload('h2s'), + eagerload_all('h3s.h1s')).all() + self.assertEqual(len(h1s), 5) + +class NoEqFoo(object): + def __init__(self, data): + self.data = data + def __eq__(self, other): + raise NotImplementedError() + def __ne__(self, other): + raise NotImplementedError() + +class ScalarRequirementsTest(ORMTest): + def define_tables(self, metadata): + import pickle + global t1 + t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), + Column('data', PickleType(pickler=pickle)) # dont use cPickle due to import weirdness + ) + + def test_correct_comparison(self): + + class H1(fixtures.Base): + pass + + mapper(H1, t1) + + h1 = H1(data=NoEqFoo('12345')) + s = create_session() + s.save(h1) + s.flush() + s.clear() + h1 = s.get(H1, h1.id) + assert h1.data.data == '12345' - q = create_session().query(m) - l = q.select(); - self.assert_result(l, User, *user_address_result) - -if __name__ == "__main__": - testbase.main() +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/memusage.py b/test/orm/memusage.py index 26da7c010d..1851639edf 100644 --- a/test/orm/memusage.py +++ b/test/orm/memusage.py @@ -1,76 +1,283 @@ -import testbase +import testenv; testenv.configure_for_tests() import gc from 
sqlalchemy import MetaData, Integer, String, ForeignKey from sqlalchemy.orm import mapper, relation, clear_mappers, create_session -from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.orm.mapper import Mapper, _mapper_registry +from sqlalchemy.orm.session import _sessions from testlib import * +from testlib.fixtures import Base -class A(object):pass -class B(object):pass +class A(Base):pass +class B(Base):pass -class MapperCleanoutTest(AssertMixin): - """test that clear_mappers() removes everything related to the class. - - does not include classes that use the assignmapper extension.""" - - def test_mapper_cleanup(self): - for x in range(0, 5): - self.do_test() +def profile_memory(func): + # run the test 50 times. if length of gc.get_objects() + # keeps growing, assert false + def profile(*args): + samples = [] + for x in range(0, 50): + func(*args) gc.collect() - for o in gc.get_objects(): - if isinstance(o, Mapper): - # the classes in the 'tables' package have assign_mapper called on them - # which is particularly sticky - # if getattr(tables, o.class_.__name__, None) is o.class_: - # continue - # well really we are just testing our own classes here - if (o.class_ not in [A,B]): - continue - assert False - assert True + samples.append(len(gc.get_objects())) + print "sample gc sizes:", samples + + assert len(_sessions) == 0 + + # TODO: this test only finds pure "growing" tests. + # if a drop is detected, it's assumed that GC is able + # to reduce memory. better methodology would + # make this more accurate. + for i, x in enumerate(samples): + if i < len(samples) - 1 and x < samples[i+1]: + continue + else: + return + assert False, repr(samples) + return profile + +def assert_no_mappers(): + clear_mappers() + gc.collect() + assert len(_mapper_registry) == 0 + +class EnsureZeroed(TestBase, AssertsExecutionResults): + def setUp(self): + _sessions.clear() + _mapper_registry.clear() - def do_test(self): - metadata = MetaData(testbase.db) +class MemUsageTest(EnsureZeroed): - table1 = Table("mytable", metadata, + def test_session(self): + metadata = MetaData(testing.db) + + table1 = Table("mytable", metadata, Column('col1', Integer, primary_key=True), Column('col2', String(30)) ) - table2 = Table("mytable2", metadata, + table2 = Table("mytable2", metadata, Column('col1', Integer, primary_key=True), Column('col2', String(30)), Column('col3', Integer, ForeignKey("mytable.col1")) ) - - metadata.create_all() + metadata.create_all() m1 = mapper(A, table1, properties={ - "bs":relation(B) + "bs":relation(B, cascade="all, delete") }) m2 = mapper(B, table2) m3 = mapper(A, table1, non_primary=True) - - sess = create_session() - a1 = A() - a2 = A() - a3 = A() - a1.bs.append(B()) - a1.bs.append(B()) - a3.bs.append(B()) - for x in [a1,a2,a3]: - sess.save(x) - sess.flush() - sess.clear() - - alist = sess.query(A).select() - for a in alist: - print "A", a, "BS", [b for b in a.bs] - + + @profile_memory + def go(): + sess = create_session() + a1 = A(col2="a1") + a2 = A(col2="a2") + a3 = A(col2="a3") + a1.bs.append(B(col2="b1")) + a1.bs.append(B(col2="b2")) + a3.bs.append(B(col2="b3")) + for x in [a1,a2,a3]: + sess.save(x) + sess.flush() + sess.clear() + + alist = sess.query(A).all() + self.assertEquals( + [ + A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), + A(col2="a2", bs=[]), + A(col2="a3", bs=[B(col2="b3")]) + ], + alist) + + for a in alist: + sess.delete(a) + sess.flush() + go() + metadata.drop_all() - clear_mappers() - + del m1, m2, m3 + assert_no_mappers() + + def test_mapper_reset(self): + metadata 
= MetaData(testing.db) + + table1 = Table("mytable", metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)) + ) + + table2 = Table("mytable2", metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)), + Column('col3', Integer, ForeignKey("mytable.col1")) + ) + + @profile_memory + def go(): + m1 = mapper(A, table1, properties={ + "bs":relation(B) + }) + m2 = mapper(B, table2) + + m3 = mapper(A, table1, non_primary=True) + + sess = create_session() + a1 = A(col2="a1") + a2 = A(col2="a2") + a3 = A(col2="a3") + a1.bs.append(B(col2="b1")) + a1.bs.append(B(col2="b2")) + a3.bs.append(B(col2="b3")) + for x in [a1,a2,a3]: + sess.save(x) + sess.flush() + sess.clear() + + alist = sess.query(A).all() + self.assertEquals( + [ + A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), + A(col2="a2", bs=[]), + A(col2="a3", bs=[B(col2="b3")]) + ], + alist) + + for a in alist: + sess.delete(a) + sess.flush() + sess.close() + clear_mappers() + + metadata.create_all() + try: + go() + finally: + metadata.drop_all() + assert_no_mappers() + + def test_with_inheritance(self): + metadata = MetaData(testing.db) + + table1 = Table("mytable", metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)) + ) + + table2 = Table("mytable2", metadata, + Column('col1', Integer, ForeignKey('mytable.col1'), primary_key=True), + Column('col3', String(30)), + ) + + @profile_memory + def go(): + class A(Base): + pass + class B(A): + pass + + mapper(A, table1, polymorphic_on=table1.c.col2, polymorphic_identity='a') + mapper(B, table2, inherits=A, polymorphic_identity='b') + + sess = create_session() + a1 = A() + a2 = A() + b1 = B(col3='b1') + b2 = B(col3='b2') + for x in [a1,a2,b1, b2]: + sess.save(x) + sess.flush() + sess.clear() + + alist = sess.query(A).all() + self.assertEquals( + [ + A(), A(), B(col3='b1'), B(col3='b2') + ], + alist) + + for a in alist: + sess.delete(a) + sess.flush() + + # dont need to clear_mappers() + del B + del A + + metadata.create_all() + try: + go() + finally: + metadata.drop_all() + assert_no_mappers() + + def test_with_manytomany(self): + metadata = MetaData(testing.db) + + table1 = Table("mytable", metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)) + ) + + table2 = Table("mytable2", metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)), + ) + + table3 = Table('t1tot2', metadata, + Column('t1', Integer, ForeignKey('mytable.col1')), + Column('t2', Integer, ForeignKey('mytable2.col1')), + ) + + @profile_memory + def go(): + class A(Base): + pass + class B(Base): + pass + + mapper(A, table1, properties={ + 'bs':relation(B, secondary=table3, backref='as') + }) + mapper(B, table2) + + sess = create_session() + a1 = A(col2='a1') + a2 = A(col2='a2') + b1 = B(col2='b1') + b2 = B(col2='b2') + a1.bs.append(b1) + a2.bs.append(b2) + for x in [a1,a2]: + sess.save(x) + sess.flush() + sess.clear() + + alist = sess.query(A).all() + self.assertEquals( + [ + A(bs=[B(col2='b1')]), A(bs=[B(col2='b2')]) + ], + alist) + + for a in alist: + sess.delete(a) + sess.flush() + + # dont need to clear_mappers() + del B + del A + + metadata.create_all() + try: + go() + finally: + metadata.drop_all() + assert_no_mappers() + + if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/orm/merge.py b/test/orm/merge.py index 3dd0a95a47..fd61ccc28c 100644 --- a/test/orm/merge.py +++ b/test/orm/merge.py @@ -1,117 +1,271 @@ -import testbase +import testenv; 
testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions from sqlalchemy.orm import * +from sqlalchemy.orm import mapperlib +from sqlalchemy.util import OrderedSet from testlib import * +from testlib import fixtures from testlib.tables import * import testlib.tables as tables -class MergeTest(AssertMixin): +class MergeTest(TestBase, AssertsExecutionResults): """tests session.merge() functionality""" def setUpAll(self): tables.create() + def tearDownAll(self): tables.drop() + def tearDown(self): clear_mappers() tables.delete() - def setUp(self): - pass - - def test_unsaved(self): - """test merge of a single transient entity.""" + + def test_transient_to_pending(self): + class User(fixtures.Base): + pass mapper(User, users) sess = create_session() - - u = User() - u.user_id = 7 - u.user_name = "fred" + + u = User(user_id=7, user_name='fred') u2 = sess.merge(u) assert u2 in sess - assert u2.user_id == 7 - assert u2.user_name == 'fred' + self.assertEquals(u2, User(user_id=7, user_name='fred')) + sess.flush() + sess.clear() + self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred')) + + def test_transient_to_pending_collection(self): + class User(fixtures.Base): + pass + class Address(fixtures.Base): + pass + mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)}) + mapper(Address, addresses) + + u = User(user_id=7, user_name='fred', addresses=OrderedSet([ + Address(address_id=1, email_address='fred1'), + Address(address_id=2, email_address='fred2'), + ])) + sess = create_session() + sess.merge(u) sess.flush() sess.clear() - u2 = sess.query(User).get(7) - assert u2.user_name == 'fred' + self.assertEquals(sess.query(User).one(), + User(user_id=7, user_name='fred', addresses=OrderedSet([ + Address(address_id=1, email_address='fred1'), + Address(address_id=2, email_address='fred2'), + ])) + ) + + def test_transient_to_persistent(self): + class User(fixtures.Base): + pass + mapper(User, users) + sess = create_session() + u = User(user_id=7, user_name='fred') + sess.save(u) + sess.flush() + sess.clear() + + u2 = User(user_id=7, user_name='fred jones') + u2 = sess.merge(u2) + sess.flush() + sess.clear() + self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred jones')) + + def test_transient_to_persistent_collection(self): + class User(fixtures.Base): + pass + class Address(fixtures.Base): + pass + mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)}) + mapper(Address, addresses) + + u = User(user_id=7, user_name='fred', addresses=OrderedSet([ + Address(address_id=1, email_address='fred1'), + Address(address_id=2, email_address='fred2'), + ])) + sess = create_session() + sess.save(u) + sess.flush() + sess.clear() + + u = User(user_id=7, user_name='fred', addresses=OrderedSet([ + Address(address_id=3, email_address='fred3'), + Address(address_id=4, email_address='fred4'), + ])) + + u = sess.merge(u) + self.assertEquals(u, + User(user_id=7, user_name='fred', addresses=OrderedSet([ + Address(address_id=3, email_address='fred3'), + Address(address_id=4, email_address='fred4'), + ])) + ) + sess.flush() + sess.clear() + self.assertEquals(sess.query(User).one(), + User(user_id=7, user_name='fred', addresses=OrderedSet([ + Address(address_id=3, email_address='fred3'), + Address(address_id=4, email_address='fred4'), + ])) + ) + + def test_detached_to_persistent_collection(self): + class User(fixtures.Base): + pass + class 
Address(fixtures.Base): + pass + mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)}) + mapper(Address, addresses) + + a = Address(address_id=1, email_address='fred1') + u = User(user_id=7, user_name='fred', addresses=OrderedSet([ + a, + Address(address_id=2, email_address='fred2'), + ])) + sess = create_session() + sess.save(u) + sess.flush() + sess.clear() + + u.user_name='fred jones' + u.addresses.add(Address(address_id=3, email_address='fred3')) + u.addresses.remove(a) + + u = sess.merge(u) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(User).first(), + User(user_id=7, user_name='fred jones', addresses=OrderedSet([ + Address(address_id=2, email_address='fred2'), + Address(address_id=3, email_address='fred3'), + ])) + ) + def test_unsaved_cascade(self): - """test merge of a transient entity with two child transient entities.""" + """test merge of a transient entity with two child transient entities, with a bidirectional relation.""" + + class User(fixtures.Base): + pass + class Address(fixtures.Base): + pass + mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), cascade="all") + 'addresses':relation(mapper(Address, addresses), cascade="all", backref="user") }) sess = create_session() - u = User() - u.user_id = 7 - u.user_name = "fred" - a1 = Address() - a1.email_address='foo@bar.com' - a2 = Address() - a2.email_address = 'hoho@la.com' + u = User(user_id=7, user_name='fred') + a1 = Address(email_address='foo@bar.com') + a2 = Address(email_address='hoho@bar.com') u.addresses.append(a1) u.addresses.append(a2) - + u2 = sess.merge(u) - self.assert_result([u], User, {'user_id':7, 'user_name':'fred', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@la.com'}])}) - self.assert_result([u2], User, {'user_id':7, 'user_name':'fred', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@la.com'}])}) + self.assertEquals(u, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')])) + self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')])) sess.flush() sess.clear() u2 = sess.query(User).get(7) - self.assert_result([u2], User, {'user_id':7, 'user_name':'fred', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@la.com'}])}) + self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')])) - def test_saved_cascade(self): + def test_attribute_cascade(self): """test merge of a persistent entity with two child persistent entities.""" + + class User(fixtures.Base): + pass + class Address(fixtures.Base): + pass + mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses), backref='user') }) sess = create_session() - + # set up data and save - u = User() - u.user_id = 7 - u.user_name = "fred" - a1 = Address() - a1.email_address='foo@bar.com' - a2 = Address() - a2.email_address = 'hoho@la.com' - u.addresses.append(a1) - u.addresses.append(a2) + u = User(user_id=7, user_name='fred', addresses=[ + Address(email_address='foo@bar.com'), + Address(email_address = 'hoho@la.com') + ]) sess.save(u) sess.flush() # assert data was saved sess2 = create_session() u2 = sess2.query(User).get(7) - self.assert_result([u2], User, {'user_id':7, 'user_name':'fred', 'addresses':(Address, 
[{'email_address':'foo@bar.com'}, {'email_address':'hoho@la.com'}])}) - + self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@la.com')])) + # make local changes to data u.user_name = 'fred2' u.addresses[1].email_address = 'hoho@lalala.com' - + # new session, merge modified data into session sess3 = create_session() u3 = sess3.merge(u) - # insure local changes are pending - self.assert_result([u3], User, {'user_id':7, 'user_name':'fred2', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@lalala.com'}])}) + + # ensure local changes are pending + self.assertEquals(u3, User(user_id=7, user_name='fred2', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@lalala.com')])) # save merged data sess3.flush() - + # assert modified/merged data was saved sess.clear() u = sess.query(User).get(7) - self.assert_result([u], User, {'user_id':7, 'user_name':'fred2', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@lalala.com'}])}) + self.assertEquals(u, User(user_id=7, user_name='fred2', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@lalala.com')])) + + # merge persistent object into another session + sess4 = create_session() + u = sess4.merge(u) + assert len(u.addresses) + for a in u.addresses: + assert a.user is u + def go(): + sess4.flush() + # no changes; therefore flush should do nothing + self.assert_sql_count(testing.db, go, 0) + + # test with "dontload" merge + sess5 = create_session() + u = sess5.merge(u, dont_load=True) + assert len(u.addresses) + for a in u.addresses: + assert a.user is u + def go(): + sess5.flush() + # no changes; therefore flush should do nothing + # but also, dont_load wipes out any difference in committed state, + # so no flush at all + self.assert_sql_count(testing.db, go, 0) + + sess4 = create_session() + u = sess4.merge(u, dont_load=True) + # post merge change + u.addresses[1].email_address='afafds' + def go(): + sess4.flush() + # afafds change flushes + self.assert_sql_count(testing.db, go, 1) + + sess5 = create_session() + u2 = sess5.query(User).get(u.user_id) + assert u2.user_name == 'fred2' + assert u2.addresses[1].email_address == 'afafds' + + def test_one_to_many_cascade(self): - def test_saved_cascade_2(self): - """tests a more involved merge""" mapper(Order, orders, properties={ 'items':relation(mapper(Item, orderitems)) }) - + mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses)), 'orders':relation(Order, backref='customer') }) - + sess = create_session() u = User() u.user_name='fred' @@ -124,10 +278,10 @@ class MergeTest(AssertMixin): o.items.append(i1) o.items.append(i2) u.orders.append(o) - + sess.save(u) sess.flush() - + sess2 = create_session() u2 = sess2.query(User).get(u.user_id) u.orders[0].items[1].item_name = 'item 2 modified' @@ -139,29 +293,246 @@ class MergeTest(AssertMixin): o.customer.user_name = 'also fred' sess2.merge(o) assert o2.customer.user_name == 'also fred' + + def test_one_to_one_cascade(self): + + mapper(User, users, properties={ + 'address':relation(mapper(Address, addresses),uselist = False) + }) + sess = create_session() + u = User() + u.user_id = 7 + u.user_name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.address = a1 + + sess.save(u) + sess.flush() + + sess2 = create_session() + u2 = sess2.query(User).get(7) + u2.user_name = 'fred2' + u2.address.email_address = 'hoho@lalala.com' + + u3 = 
sess.merge(u2) + + def test_transient_dontload(self): + mapper(User, users) + + sess = create_session() + u = User() + self.assertRaisesMessage(exceptions.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) + + + def test_dontload_with_backrefs(self): + """test that dontload populates relations in both directions without requiring a load""" + + class User(fixtures.Base): + pass + class Address(fixtures.Base): + pass + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), backref='user') + }) + + u = User(user_id=7, user_name='fred', addresses=[Address(email_address='ad1'), Address(email_address='ad2')]) + sess = create_session() + sess.save(u) + sess.flush() + sess.close() + assert 'user' in u.addresses[1].__dict__ + + sess = create_session() + u2 = sess.merge(u, dont_load=True) + assert 'user' in u2.addresses[1].__dict__ + self.assertEquals(u2.addresses[1].user, User(user_id=7, user_name='fred')) + + sess.expire(u2.addresses[1], ['user']) + assert 'user' not in u2.addresses[1].__dict__ + sess.close() + + sess = create_session() + u = sess.merge(u2, dont_load=True) + assert 'user' not in u.addresses[1].__dict__ + self.assertEquals(u.addresses[1].user, User(user_id=7, user_name='fred')) - def test_saved_cascade_3(self): - """test merge of a persistent entity with one_to_one relationship""" - mapper(User, users, properties={ - 'address':relation(mapper(Address, addresses),uselist = False) - }) - sess = create_session() - u = User() - u.user_id = 7 - u.user_name = "fred" - a1 = Address() - a1.email_address='foo@bar.com' - u.address = a1 - - sess.save(u) - sess.flush() - - sess2 = create_session() - u2 = sess2.query(User).get(7) - u2.user_name = 'fred2' - u2.address.email_address = 'hoho@lalala.com' - - u3 = sess.merge(u2) - -if __name__ == "__main__": - testbase.main() + + def test_dontload_with_eager(self): + """this test illustrates that with dont_load=True, we can't just + copy the committed_state of the merged instance over; since it references collection objects + which themselves are to be merged. This committed_state would instead need to be piecemeal + 'converted' to represent the correct objects. + However, at the moment I'd rather not support this use case; if you are merging with dont_load=True, + you're typically dealing with caching and the merged objects shouldnt be "dirty". + """ + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses)) + }) + sess = create_session() + u = User() + u.user_id = 7 + u.user_name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.addresses.append(a1) + + sess.save(u) + sess.flush() + + sess2 = create_session() + u2 = sess2.query(User).options(eagerload('addresses')).get(7) + + sess3 = create_session() + u3 = sess3.merge(u2, dont_load=True) + def go(): + sess3.flush() + self.assert_sql_count(testing.db, go, 0) + + def test_dont_load_disallows_dirty(self): + """dont_load doesnt support 'dirty' objects right now (see test_dont_load_with_eager()). + Therefore lets assert it.""" + + mapper(User, users) + sess = create_session() + u = User() + u.user_id = 7 + u.user_name = "fred" + sess.save(u) + sess.flush() + + u.user_name = 'ed' + sess2 = create_session() + try: + sess2.merge(u, dont_load=True) + assert False + except exceptions.InvalidRequestError, e: + assert "merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True." 
in str(e) + + u2 = sess2.query(User).get(7) + + sess3 = create_session() + u3 = sess3.merge(u2, dont_load=True) + assert not sess3.dirty + def go(): + sess3.flush() + self.assert_sql_count(testing.db, go, 0) + + def test_dont_load_sets_entityname(self): + """test that a dont_load-merged entity has entity_name set, has_mapper() passes, and lazyloads work""" + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses),uselist = True) + }) + sess = create_session() + u = User() + u.user_id = 7 + u.user_name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.addresses.append(a1) + + sess.save(u) + sess.flush() + sess.clear() + + # reload 'u' such that its addresses list hasn't loaded + u = sess.query(User).get(7) + + sess2 = create_session() + u2 = sess2.merge(u, dont_load=True) + assert not sess2.dirty + # assert merged instance has a mapper and lazy load proceeds + assert hasattr(u2, '_entity_name') + assert mapperlib.has_mapper(u2) + def go(): + assert u2.addresses != [] + assert len(u2.addresses) == 1 + self.assert_sql_count(testing.db, go, 1) + + def test_dont_load_sets_backrefs(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses),backref='user') + }) + sess = create_session() + u = User() + u.user_id = 7 + u.user_name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.addresses.append(a1) + + sess.save(u) + sess.flush() + + assert u.addresses[0].user is u + + sess2 = create_session() + u2 = sess2.merge(u, dont_load=True) + assert not sess2.dirty + def go(): + assert u2.addresses[0].user is u2 + self.assert_sql_count(testing.db, go, 0) + + def test_dont_load_preserves_parents(self): + """test that merge with dont_load does not trigger a 'delete-orphan' operation. + + merge with dont_load sets attributes without using events. this means the + 'hasparent' flag is not propagated to the newly merged instance. in fact this + works out OK, because the '_state.parents' collection on the newly + merged instance is empty; since the mapper doesn't see an active 'False' setting + in this collection when _is_orphan() is called, it does not count as an orphan + (i.e. this is the 'optimistic' logic in mapper._is_orphan().) + """ + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses),backref='user', cascade="all, delete-orphan") + }) + sess = create_session() + u = User() + u.user_id = 7 + u.user_name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.addresses.append(a1) + sess.save(u) + sess.flush() + + assert u.addresses[0].user is u + + sess2 = create_session() + u2 = sess2.merge(u, dont_load=True) + assert not sess2.dirty + a2 = u2.addresses[0] + a2.email_address='somenewaddress' + assert not object_mapper(a2)._is_orphan(a2) + sess2.flush() + sess2.clear() + assert sess2.query(User).get(u2.user_id).addresses[0].email_address == 'somenewaddress' + + # this use case is not supported; this is with a pending Address on the pre-merged + # object, and we currently dont support 'dirty' objects being merged with dont_load=True. + # in this case, the empty '_state.parents' collection would be an issue, + # since the optimistic flag is False in _is_orphan() for pending instances. 
+ # so if we start supporting 'dirty' with dont_load=True, this test will need to pass + sess = create_session() + u = sess.query(User).get(7) + u.addresses.append(Address()) + sess2 = create_session() + try: + u2 = sess2.merge(u, dont_load=True) + assert False + + # if dont_load is changed to support dirty objects, this code needs to pass + a2 = u2.addresses[0] + a2.email_address='somenewaddress' + assert not object_mapper(a2)._is_orphan(a2) + sess2.flush() + sess2.clear() + assert sess2.query(User).get(u2.user_id).addresses[0].email_address == 'somenewaddress' + except exceptions.InvalidRequestError, e: + assert "dont_load=True option does not support" in str(e) + + +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/naturalpks.py b/test/orm/naturalpks.py new file mode 100644 index 0000000000..ec7d2fca99 --- /dev/null +++ b/test/orm/naturalpks.py @@ -0,0 +1,378 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy import exceptions + +from testlib.fixtures import * +from testlib import * + +"""test primary key changing capabilities and passive/non-passive cascading updates.""" + +class NaturalPKTest(ORMTest): + def define_tables(self, metadata): + global users, addresses, items, users_to_items + + users = Table('users', metadata, + Column('username', String(50), primary_key=True), + Column('fullname', String(100))) + + addresses = Table('addresses', metadata, + Column('email', String(50), primary_key=True), + Column('username', String(50), ForeignKey('users.username', onupdate="cascade"))) + + items = Table('items', metadata, + Column('itemname', String(50), primary_key=True), + Column('description', String(100))) + + users_to_items = Table('userstoitems', metadata, + Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True), + Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True), + ) + + def test_entity(self): + mapper(User, users) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.save(u1) + sess.flush() + assert sess.get(User, 'jack') is u1 + + u1.username = 'ed' + sess.flush() + + def go(): + assert sess.get(User, 'ed') is u1 + self.assert_sql_count(testing.db, go, 0) + + assert sess.get(User, 'jack') is None + + sess.clear() + u1 = sess.query(User).get('ed') + self.assertEquals(User(username='ed', fullname='jack'), u1) + + def test_expiry(self): + mapper(User, users) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.save(u1) + sess.flush() + assert sess.get(User, 'jack') is u1 + + users.update(values={u1.c.username:'jack'}).execute(username='ed') + + try: + # expire/refresh works off of primary key. the PK is gone + # in this case so theres no way to look it up. 
criterion- + # based session invalidation could solve this [ticket:911] + sess.expire(u1) + u1.username + assert False + except exceptions.InvalidRequestError, e: + assert "Could not refresh instance" in str(e) + + sess.clear() + assert sess.get(User, 'jack') is None + assert sess.get(User, 'ed').fullname == 'jack' + + @testing.unsupported('sqlite','mysql') + def test_onetomany_passive(self): + self._test_onetomany(True) + + def test_onetomany_nonpassive(self): + self._test_onetomany(False) + + def _test_onetomany(self, passive_updates): + mapper(User, users, properties={ + 'addresses':relation(Address, passive_updates=passive_updates) + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + u1.addresses.append(Address(email='jack1')) + u1.addresses.append(Address(email='jack2')) + sess.save(u1) + sess.flush() + + assert sess.get(Address, 'jack1') is u1.addresses[0] + + u1.username = 'ed' + sess.flush() + assert u1.addresses[0].username == 'ed' + + sess.clear() + self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + + u1 = sess.get(User, 'ed') + u1.username = 'jack' + def go(): + sess.flush() + if not passive_updates: + self.assert_sql_count(testing.db, go, 4) # test passive_updates=False; load addresses, update user, update 2 addresses + else: + self.assert_sql_count(testing.db, go, 1) # test passive_updates=True; update user + sess.clear() + assert User(username='jack', addresses=[Address(username='jack'), Address(username='jack')]) == sess.get(User, 'jack') + + u1 = sess.get(User, 'jack') + u1.addresses = [] + u1.username = 'fred' + sess.flush() + sess.clear() + assert sess.get(Address, 'jack1').username is None + u1 = sess.get(User, 'fred') + self.assertEquals(User(username='fred', fullname='jack'), u1) + + @testing.unsupported('sqlite', 'mysql') + def test_manytoone_passive(self): + self._test_manytoone(True) + + def test_manytoone_nonpassive(self): + self._test_manytoone(False) + + def _test_manytoone(self, passive_updates): + mapper(User, users) + mapper(Address, addresses, properties={ + 'user':relation(User, passive_updates=passive_updates) + }) + + sess = create_session() + a1 = Address(email='jack1') + a2 = Address(email='jack2') + + u1 = User(username='jack', fullname='jack') + a1.user = u1 + a2.user = u1 + sess.save(a1) + sess.save(a2) + sess.flush() + + u1.username = 'ed' + + print id(a1), id(a2), id(u1) + print u1._state.parents + def go(): + sess.flush() + if passive_updates: + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 3) + + def go(): + sess.flush() + self.assert_sql_count(testing.db, go, 0) + + assert a1.username == a2.username == 'ed' + sess.clear() + self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + + @testing.unsupported('sqlite', 'mysql') + def test_bidirectional_passive(self): + self._test_bidirectional(True) + + def test_bidirectional_nonpassive(self): + self._test_bidirectional(False) + + def _test_bidirectional(self, passive_updates): + mapper(User, users) + mapper(Address, addresses, properties={ + 'user':relation(User, passive_updates=passive_updates, backref='addresses') + }) + + sess = create_session() + a1 = Address(email='jack1') + a2 = Address(email='jack2') + + u1 = User(username='jack', fullname='jack') + a1.user = u1 + a2.user = u1 + sess.save(a1) + sess.save(a2) + sess.flush() + + u1.username = 'ed' + (ad1, ad2) = sess.query(Address).all() + 
self.assertEquals([Address(username='jack'), Address(username='jack')], [ad1, ad2]) + def go(): + sess.flush() + if passive_updates: + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 3) + self.assertEquals([Address(username='ed'), Address(username='ed')], [ad1, ad2]) + sess.clear() + self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + + u1 = sess.get(User, 'ed') + assert len(u1.addresses) == 2 # load addresses + u1.username = 'fred' + print "--------------------------------" + def go(): + sess.flush() + # check that the passive_updates is on on the other side + if passive_updates: + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 3) + sess.clear() + self.assertEquals([Address(username='fred'), Address(username='fred')], sess.query(Address).all()) + + + @testing.unsupported('sqlite', 'mysql') + def test_manytomany_passive(self): + self._test_manytomany(True) + + def test_manytomany_nonpassive(self): + self._test_manytomany(False) + + def _test_manytomany(self, passive_updates): + mapper(User, users, properties={ + 'items':relation(Item, secondary=users_to_items, backref='users', passive_updates=passive_updates) + }) + mapper(Item, items) + + sess = create_session() + u1 = User(username='jack') + u2 = User(username='fred') + i1 = Item(itemname='item1') + i2 = Item(itemname='item2') + + u1.items.append(i1) + u1.items.append(i2) + i2.users.append(u2) + sess.save(u1) + sess.save(u2) + sess.flush() + + r = sess.query(Item).all() + # fixtures.Base can't handle a comparison with the backrefs involved.... + self.assertEquals(Item(itemname='item1'), r[0]) + self.assertEquals(['jack'], [u.username for u in r[0].users]) + self.assertEquals(Item(itemname='item2'), r[1]) + self.assertEquals(['jack', 'fred'], [u.username for u in r[1].users]) + + u2.username='ed' + def go(): + sess.flush() + go() + def go(): + sess.flush() + self.assert_sql_count(testing.db, go, 0) + + sess.clear() + r = sess.query(Item).all() + self.assertEquals(Item(itemname='item1'), r[0]) + self.assertEquals(['jack'], [u.username for u in r[0].users]) + self.assertEquals(Item(itemname='item2'), r[1]) + self.assertEquals(['ed', 'jack'], sorted([u.username for u in r[1].users])) + +class SelfRefTest(ORMTest): + def define_tables(self, metadata): + global nodes, Node + + nodes = Table('nodes', metadata, + Column('name', String(50), primary_key=True), + Column('parent', String(50), ForeignKey('nodes.name', onupdate='cascade')) + ) + + class Node(Base): + pass + + def test_onetomany(self): + mapper(Node, nodes, properties={ + 'children':relation(Node, backref=backref('parentnode', remote_side=nodes.c.name, passive_updates=False), passive_updates=False) + }) + + sess = create_session() + n1 = Node(name='n1') + n1.children.append(Node(name='n11')) + n1.children.append(Node(name='n12')) + n1.children.append(Node(name='n13')) + sess.save(n1) + sess.flush() + + n1.name = 'new n1' + sess.flush() + self.assertEquals(n1.children[1].parent, 'new n1') + self.assertEquals(['new n1', 'new n1', 'new n1'], [n.parent for n in sess.query(Node).filter(Node.name.in_(['n11', 'n12', 'n13']))]) + + +class NonPKCascadeTest(ORMTest): + def define_tables(self, metadata): + global users, addresses + + users = Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('username', String(50), unique=True), + Column('fullname', String(100))) + + addresses = Table('addresses', metadata, + Column('id', Integer, 
primary_key=True), + Column('email', String(50)), + Column('username', String(50), ForeignKey('users.username', onupdate="cascade"))) + + @testing.unsupported('sqlite','mysql') + def test_onetomany_passive(self): + self._test_onetomany(True) + + def test_onetomany_nonpassive(self): + self._test_onetomany(False) + + def _test_onetomany(self, passive_updates): + mapper(User, users, properties={ + 'addresses':relation(Address, passive_updates=passive_updates) + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + u1.addresses.append(Address(email='jack1')) + u1.addresses.append(Address(email='jack2')) + sess.save(u1) + sess.flush() + a1 = u1.addresses[0] + + self.assertEquals(select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)]) + + assert sess.get(Address, a1.id) is u1.addresses[0] + + u1.username = 'ed' + sess.flush() + assert u1.addresses[0].username == 'ed' + self.assertEquals(select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)]) + + sess.clear() + self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + + u1 = sess.get(User, u1.id) + u1.username = 'jack' + def go(): + sess.flush() + if not passive_updates: + self.assert_sql_count(testing.db, go, 4) # test passive_updates=False; load addresses, update user, update 2 addresses + else: + self.assert_sql_count(testing.db, go, 1) # test passive_updates=True; update user + sess.clear() + assert User(username='jack', addresses=[Address(username='jack'), Address(username='jack')]) == sess.get(User, u1.id) + sess.clear() + + u1 = sess.get(User, u1.id) + u1.addresses = [] + u1.username = 'fred' + sess.flush() + sess.clear() + a1 = sess.get(Address, a1.id) + self.assertEquals(a1.username, None) + + self.assertEquals(select([addresses.c.username]).execute().fetchall(), [(None,), (None,)]) + + u1 = sess.get(User, u1.id) + self.assertEquals(User(username='fred', fullname='jack'), u1) + + +if __name__ == '__main__': + testenv.main() diff --git a/test/orm/onetoone.py b/test/orm/onetoone.py index e41fa1d20f..ae0d6ef86d 100644 --- a/test/orm/onetoone.py +++ b/test/orm/onetoone.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext @@ -24,12 +24,13 @@ class Port(object): self.name=name self.description = description -class O2OTest(AssertMixin): +class O2OTest(TestBase, AssertsExecutionResults): + @testing.uses_deprecated('SessionContext') def setUpAll(self): global jack, port, metadata, ctx - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) ctx = SessionContext(create_session) - jack = Table('jack', metadata, + jack = Table('jack', metadata, Column('id', Integer, primary_key=True), #Column('room_id', Integer, ForeignKey("room.id")), Column('number', String(50)), @@ -38,7 +39,7 @@ class O2OTest(AssertMixin): ) - port = Table('port', metadata, + port = Table('port', metadata, Column('id', Integer, primary_key=True), #Column('device_id', Integer, ForeignKey("device.id")), Column('name', String(30)), @@ -52,12 +53,13 @@ class O2OTest(AssertMixin): clear_mappers() def tearDownAll(self): metadata.drop_all() - + + @testing.uses_deprecated('SessionContext') def test1(self): mapper(Port, port, extension=ctx.mapper_extension) mapper(Jack, jack, order_by=[jack.c.number],properties = { 'port': relation(Port, backref='jack', uselist=False, lazy=True), - }, 
extension=ctx.mapper_extension) + }, extension=ctx.mapper_extension) j=Jack(number='101') p=Port(name='fa0/1') @@ -87,5 +89,5 @@ class O2OTest(AssertMixin): ctx.current.delete(j) ctx.current.flush() -if __name__ == "__main__": - testbase.main() +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/pickled.py b/test/orm/pickled.py new file mode 100644 index 0000000000..84f5e5dafb --- /dev/null +++ b/test/orm/pickled.py @@ -0,0 +1,134 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy import exceptions +from sqlalchemy.orm import * +from testlib import * +from testlib.fixtures import * +import pickle + +class EmailUser(User): + pass + +class PickleTest(FixtureTest): + keep_mappers = False + keep_data = False + + def test_transient(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref="user") + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(name='ed') + u1.addresses.append(Address(email_address='ed@bar.com')) + + u2 = pickle.loads(pickle.dumps(u1)) + sess.save(u2) + sess.flush() + + sess.clear() + + self.assertEquals(u1, sess.query(User).get(u2.id)) + + def test_class_deferred_cols(self): + mapper(User, users, properties={ + 'name':deferred(users.c.name), + 'addresses':relation(Address, backref="user") + }) + mapper(Address, addresses, properties={ + 'email_address':deferred(addresses.c.email_address) + }) + sess = create_session() + u1 = User(name='ed') + u1.addresses.append(Address(email_address='ed@bar.com')) + sess.save(u1) + sess.flush() + sess.clear() + u1 = sess.query(User).get(u1.id) + assert 'name' not in u1.__dict__ + assert 'addresses' not in u1.__dict__ + + u2 = pickle.loads(pickle.dumps(u1)) + sess2 = create_session() + sess2.update(u2) + self.assertEquals(u2.name, 'ed') + self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + + u2 = pickle.loads(pickle.dumps(u1)) + sess2 = create_session() + u2 = sess2.merge(u2, dont_load=True) + self.assertEquals(u2.name, 'ed') + self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + + def test_instance_deferred_cols(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref="user") + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(name='ed') + u1.addresses.append(Address(email_address='ed@bar.com')) + sess.save(u1) + sess.flush() + sess.clear() + + u1 = sess.query(User).options(defer('name'), defer('addresses.email_address')).get(u1.id) + assert 'name' not in u1.__dict__ + assert 'addresses' not in u1.__dict__ + + u2 = pickle.loads(pickle.dumps(u1)) + sess2 = create_session() + sess2.update(u2) + self.assertEquals(u2.name, 'ed') + assert 'addresses' not in u2.__dict__ + ad = u2.addresses[0] + assert 'email_address' not in ad.__dict__ + self.assertEquals(ad.email_address, 'ed@bar.com') + self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + + u2 = pickle.loads(pickle.dumps(u1)) + sess2 = create_session() + u2 = sess2.merge(u2, dont_load=True) + self.assertEquals(u2.name, 'ed') + assert 'addresses' not in u2.__dict__ + ad = u2.addresses[0] + assert 'email_address' in ad.__dict__ # mapper options dont transmit over merge() right now + self.assertEquals(ad.email_address, 'ed@bar.com') + self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + + +class PolymorphicDeferredTest(ORMTest): + def define_tables(self, metadata): + global users, email_users + users = 
Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30)), + Column('type', String(30)), + ) + email_users = Table('email_users', metadata, + Column('id', Integer, ForeignKey('users.id'), primary_key=True), + Column('email_address', String(30)) + ) + + def test_polymorphic_deferred(self): + mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type, polymorphic_fetch='deferred') + mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser') + + eu = EmailUser(name="user1", email_address='foo@bar.com') + sess = create_session() + sess.save(eu) + sess.flush() + sess.clear() + + eu = sess.query(User).first() + eu2 = pickle.loads(pickle.dumps(eu)) + sess2 = create_session() + sess2.update(eu2) + assert 'email_address' not in eu2.__dict__ + self.assertEquals(eu2.email_address, 'foo@bar.com') + + +if __name__ == '__main__': + testenv.main() diff --git a/test/orm/query.py b/test/orm/query.py index 3783e1fa0c..f1afdb90b4 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1,35 +1,30 @@ -import testbase +import testenv; testenv.configure_for_tests() import operator from sqlalchemy import * -from sqlalchemy import ansisql +from sqlalchemy import exceptions, util +from sqlalchemy.sql import compiler +from sqlalchemy.engine import default from sqlalchemy.orm import * + from testlib import * -from fixtures import * +from testlib import engines +from testlib.fixtures import * + +from sqlalchemy.orm.util import _join as join, _outerjoin as outerjoin -class QueryTest(ORMTest): +class QueryTest(FixtureTest): keep_mappers = True keep_data = True - - def setUpAll(self): - super(QueryTest, self).setUpAll() - install_fixture_data() - self.setup_mappers() - - def tearDownAll(self): - clear_mappers() - super(QueryTest, self).tearDownAll() - - def define_tables(self, meta): - # a slight dirty trick here. - meta.tables = metadata.tables - metadata.connect(meta.bind) - + def setup_mappers(self): mapper(User, users, properties={ 'addresses':relation(Address, backref='user'), 'orders':relation(Order, backref='user'), # o2m, m2o }) - mapper(Address, addresses) + mapper(Address, addresses, properties={ + 'dingaling':relation(Dingaling, uselist=False, backref="address") #o2o + }) + mapper(Dingaling, dingalings) mapper(Order, orders, properties={ 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m 'address':relation(Address), # m2o @@ -39,6 +34,24 @@ class QueryTest(ORMTest): }) mapper(Keyword, keywords) +class UnicodeSchemaTest(QueryTest): + keep_mappers = False + + def setup_mappers(self): + pass + + def define_tables(self, metadata): + super(UnicodeSchemaTest, self).define_tables(metadata) + global uni_meta, uni_users + uni_meta = MetaData() + uni_users = Table(u'users', uni_meta, + Column(u'id', Integer, primary_key=True), + Column(u'name', String(30), nullable=False)) + + def test_get(self): + mapper(User, uni_users) + assert User(id=7) == create_session(bind=testing.db).query(User).get(7) + class GetTest(QueryTest): def test_get(self): s = create_session() @@ -50,46 +63,103 @@ class GetTest(QueryTest): u2 = s.query(User).get(7) assert u is not u2 - def test_load(self): + def test_no_criterion(self): + """test that get()/load() does not use preexisting filter/etc. 
criterion""" + s = create_session() + try: + s.query(User).join('addresses').filter(Address.user_id==8).get(7) + assert False + except exceptions.SAWarning, e: + assert str(e) == "Query.get() being called on a Query with existing criterion; criterion is being ignored." + + @testing.emits_warning('Query.*') + def warns(): + assert s.query(User).filter(User.id==7).get(19) is None + + u = s.query(User).get(7) + assert s.query(User).filter(User.id==9).get(7) is u + s.clear() + assert s.query(User).filter(User.id==9).get(7).id == u.id + + # user 10 has no addresses + u = s.query(User).get(10) + assert s.query(User).join('addresses').get(10) is u + s.clear() + assert s.query(User).join('addresses').get(10).id == u.id + + u = s.query(User).get(7) + assert s.query(User).join('addresses').filter(Address.user_id==8).filter(User.id==7).first() is None + assert s.query(User).join('addresses').filter(Address.user_id==8).get(7) is u + s.clear() + assert s.query(User).join('addresses').filter(Address.user_id==8).get(7).id == u.id + + assert s.query(User).join('addresses').filter(Address.user_id==8).load(7).id == u.id + warns() + + def test_unique_param_names(self): + class SomeUser(object): + pass + s = users.select(users.c.id!=12).alias('users') + m = mapper(SomeUser, s) + print s.primary_key + print m.primary_key + assert s.primary_key == m.primary_key + + row = s.select(use_labels=True).execute().fetchone() + print row[s.primary_key[0]] + + sess = create_session() + assert sess.query(SomeUser).get(7).name == 'jack' + + def test_load(self): + s = create_session() + try: assert s.query(User).load(19) is None assert False except exceptions.InvalidRequestError: assert True - + u = s.query(User).load(7) u2 = s.query(User).load(7) assert u is u2 s.clear() u2 = s.query(User).load(7) assert u is not u2 - + u2.name = 'some name' - a = Address(name='some other name') + a = Address(email_address='some other name') u2.addresses.append(a) assert u2 in s.dirty assert a in u2.addresses - + s.query(User).load(7) assert u2 not in s.dirty assert u2.name =='jack' assert a not in u2.addresses - + + @testing.exclude('mysql', '<', (4, 1)) def test_unicode(self): - """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail + """test that Query.get properly sets up the type for the bind parameter. 
using unicode would normally fail on postgres, mysql and oracle unless it is converted to an encoded string""" - - table = Table('unicode_data', users.metadata, + + metadata = MetaData(engines.utf8_engine()) + table = Table('unicode_data', metadata, Column('id', Unicode(40), primary_key=True), Column('data', Unicode(40))) - table.create() - ustring = 'petit voix m\xe2\x80\x99a '.decode('utf-8') - table.insert().execute(id=ustring, data=ustring) - class LocalFoo(Base):pass - mapper(LocalFoo, table) - assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring) + try: + metadata.create_all() + ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8') + table.insert().execute(id=ustring, data=ustring) + class LocalFoo(Base): + pass + mapper(LocalFoo, table) + self.assertEquals(create_session().query(LocalFoo).get(ustring), + LocalFoo(id=ustring, data=ustring)) + finally: + metadata.drop_all() def test_populate_existing(self): s = create_session() @@ -117,33 +187,33 @@ class GetTest(QueryTest): s.query(User).populate_existing().all() assert u.addresses[0].email_address == 'lala' assert u.orders[1].items[2].description == 'item 12' - + # eager load does s.query(User).options(eagerload('addresses'), eagerload_all('orders.items')).populate_existing().all() assert u.addresses[0].email_address == 'jack@bean.com' assert u.orders[1].items[2].description == 'item 5' - + class OperatorTest(QueryTest): """test sql.Comparator implementation for MapperProperties""" - + def _test(self, clause, expected): - c = str(clause.compile(dialect=ansisql.ANSIDialect())) + c = str(clause.compile(dialect = default.DefaultDialect())) assert c == expected, "%s != %s" % (c, expected) - + def test_arithmetic(self): create_session().query(User) for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), (operator.sub, '-'), (operator.div, '/'), ): for (lhs, rhs, res) in ( - (5, User.id, ':users_id %s users.id'), - (5, literal(6), ':literal %s :literal_1'), - (User.id, 5, 'users.id %s :users_id'), - (User.id, literal('b'), 'users.id %s :literal'), + (5, User.id, ':id_1 %s users.id'), + (5, literal(6), ':param_1 %s :param_2'), + (User.id, 5, 'users.id %s :id_1'), + (User.id, literal('b'), 'users.id %s :param_1'), (User.id, User.id, 'users.id %s users.id'), - (literal(5), 'b', ':literal %s :literal_1'), - (literal(5), User.id, ':literal %s users.id'), - (literal(5), literal(6), ':literal %s :literal_1'), + (literal(5), 'b', ':param_1 %s :param_2'), + (literal(5), User.id, ':param_1 %s users.id'), + (literal(5), literal(6), ':param_1 %s :param_2'), ): self._test(py_op(lhs, rhs), res % sql_op) @@ -156,57 +226,65 @@ class OperatorTest(QueryTest): (operator.le, '<=', '>='), (operator.ge, '>=', '<=')): for (lhs, rhs, l_sql, r_sql) in ( - ('a', User.id, ':users_id', 'users.id'), - ('a', literal('b'), ':literal_1', ':literal'), # note swap! - (User.id, 'b', 'users.id', ':users_id'), - (User.id, literal('b'), 'users.id', ':literal'), + ('a', User.id, ':id_1', 'users.id'), + ('a', literal('b'), ':param_2', ':param_1'), # note swap! 
+ (User.id, 'b', 'users.id', ':id_1'), + (User.id, literal('b'), 'users.id', ':param_1'), (User.id, User.id, 'users.id', 'users.id'), - (literal('a'), 'b', ':literal', ':literal_1'), - (literal('a'), User.id, ':literal', 'users.id'), - (literal('a'), literal('b'), ':literal', ':literal_1'), + (literal('a'), 'b', ':param_1', ':param_2'), + (literal('a'), User.id, ':param_1', 'users.id'), + (literal('a'), literal('b'), ':param_1', ':param_2'), ): # the compiled clause should match either (e.g.): # 'a' < 'b' -or- 'b' > 'a'. - compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect())) + compiled = str(py_op(lhs, rhs).compile(dialect=default.DefaultDialect())) fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) self.assert_(compiled == fwd_sql or compiled == rev_sql, "\n'" + compiled + "'\n does not match\n'" + fwd_sql + "'\n or\n'" + rev_sql + "'") - + + def test_op(self): + assert str(User.name.op('ilike')('17').compile(dialect=default.DefaultDialect())) == "users.name ilike :name_1" + def test_in(self): - self._test(User.id.in_('a', 'b'), "users.id IN (:users_id, :users_id_1)") - + self._test(User.id.in_(['a', 'b']), + "users.id IN (:id_1, :id_2)") + + def test_between(self): + self._test(User.id.between('a', 'b'), + "users.id BETWEEN :id_1 AND :id_2") + def test_clauses(self): for (expr, compare) in ( (func.max(User.id), "max(users.id)"), - (desc(User.id), "users.id DESC"), - (between(5, User.id, Address.id), ":literal BETWEEN users.id AND addresses.id"), + (User.id.desc(), "users.id DESC"), + (between(5, User.id, Address.id), ":param_1 BETWEEN users.id AND addresses.id"), # this one would require adding compile() to InstrumentedScalarAttribute. do we want this ? #(User.id, "users.id") ): - c = expr.compile(dialect=ansisql.ANSIDialect()) + c = expr.compile(dialect=default.DefaultDialect()) assert str(c) == compare, "%s != %s" % (str(c), compare) - - + + class CompileTest(QueryTest): def test_deferred(self): session = create_session() s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).compile() - + l = session.query(User).instances(s.execute(emailad = 'jack@bean.com')) assert [User(id=7)] == l - + class SliceTest(QueryTest): def test_first(self): assert User(id=7) == create_session().query(User).first() - + assert create_session().query(User).filter(User.id==27).first() is None - + # more slice tests are available in test/orm/generative.py - + class TextTest(QueryTest): def test_fulltext(self): assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).from_statement("select * from users").all() @@ -220,24 +298,25 @@ class TextTest(QueryTest): def test_binds(self): assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all() - + class FilterTest(QueryTest): def test_basic(self): assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).all() + @testing.fails_on('maxdb') def test_limit(self): assert [User(id=8), User(id=9)] == create_session().query(User).limit(2).offset(1).all() assert [User(id=8), User(id=9)] == list(create_session().query(User)[1:3]) assert User(id=8) == create_session().query(User)[1] - + def test_onefilter(self): assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all() def test_contains(self): """test comparing a collection to an object instance.""" - + sess = create_session() address = 
sess.query(Address).get(3) assert [User(id=8)] == sess.query(User).filter(User.addresses.contains(address)).all() @@ -255,9 +334,9 @@ class FilterTest(QueryTest): assert False except exceptions.InvalidRequestError: assert True - + #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() - + def test_any(self): sess = create_session() @@ -265,16 +344,49 @@ class FilterTest(QueryTest): assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'), id=4)).all() + assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'))).\ + filter(User.addresses.any(id=4)).all() + assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all() - + + @testing.fails_on_everything_except() + def test_broken_any_1(self): + sess = create_session() + + # overcorrelates + assert [User(id=7), User(id=8)] == sess.query(User).join("addresses").filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all() + + def test_broken_any_2(self): + sess = create_session() + + # works, filter is before the join + assert [User(id=7), User(id=8)] == sess.query(User).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).join("addresses", aliased=True).all() + + def test_broken_any_3(self): + sess = create_session() + + # works, filter is after the join, but reset_joinpoint is called, removing aliasing + assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(Address.email_address != None).reset_joinpoint().filter(~User.addresses.any(email_address='fred@fred.com')).all() + + @testing.fails_on_everything_except() + def test_broken_any_4(self): + sess = create_session() + + # filter is after the join, gets aliased. 
in 0.5 any(), has() and not contains() are shielded from aliasing + assert [User(id=10)] == sess.query(User).outerjoin("addresses", aliased=True).filter(~User.addresses.any()).all() + + @testing.unsupported('maxdb') # can core def test_has(self): sess = create_session() assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all() - + assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all() - + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() - + + dingaling = sess.query(Dingaling).get(2) + assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all() + def test_contains_m2m(self): sess = create_session() item = sess.query(Item).get(3) @@ -284,30 +396,141 @@ class FilterTest(QueryTest): def test_comparison(self): """test scalar comparison to an object instance""" - + sess = create_session() user = sess.query(User).get(8) assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user==user).all() assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all() + # generates an IS NULL + assert [] == sess.query(Address).filter(Address.user == None).all() + + assert [Order(id=5)] == sess.query(Order).filter(Order.address == None).all() + + # o2o + dingaling = sess.query(Dingaling).get(2) + assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all() + + # m2m + self.assertEquals(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) + self.assertEquals(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)]) + + def test_filter_by(self): + sess = create_session() + user = sess.query(User).get(8) + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter_by(user=user).all() + + # many to one generates IS NULL + assert [] == sess.query(Address).filter_by(user = None).all() + + # one to many generates WHERE NOT EXISTS + assert [User(name='chuck')] == sess.query(User).filter_by(addresses = None).all() + + def test_none_comparison(self): + sess = create_session() + + # o2o + self.assertEquals([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) + self.assertEquals([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) + + # m2o + self.assertEquals([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) + self.assertEquals([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).filter(Order.address!=None).all()) + + # o2m + self.assertEquals([User(id=10)], sess.query(User).filter(User.addresses==None).all()) + self.assertEquals([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).all()) + +class FromSelfTest(QueryTest): + def test_filter(self): + + assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().all() + + assert [User(id=8), User(id=9)] == create_session().query(User)[1:3]._from_self().all() + assert [User(id=8)] == list(create_session().query(User).filter(User.id.in_([8,9]))._from_self()[0:1]) + + def test_join(self): + assert [ + (User(id=8), Address(id=2)), + (User(id=8), Address(id=3)), + (User(id=8), Address(id=4)), + (User(id=9), Address(id=5)) + ] == 
create_session().query(User).filter(User.id.in_([8,9]))._from_self().join('addresses').add_entity(Address).order_by(User.id, Address.id).all() + class AggregateTest(QueryTest): def test_sum(self): sess = create_session() - orders = sess.query(Order).filter(Order.id.in_(2, 3, 4)) + orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) assert orders.sum(Order.user_id * Order.address_id) == 79 + @testing.uses_deprecated('Call to deprecated function apply_sum') def test_apply(self): sess = create_session() - assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_(2, 3, 4)).one() == 79 - - + assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_([2, 3, 4])).one() == 79 + + def test_having(self): + sess = create_session() + assert [User(name=u'ed',id=8)] == sess.query(User).group_by([c for c in User.c]).join('addresses').having(func.count(Address.c.id)> 2).all() + + assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).group_by([c for c in User.c]).join('addresses').having(func.count(Address.c.id)< 2).all() + class CountTest(QueryTest): def test_basic(self): assert 4 == create_session().query(User).count() assert 2 == create_session().query(User).filter(users.c.name.endswith('ed')).count() +class DistinctTest(QueryTest): + def test_basic(self): + assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all() + assert [User(id=7), User(id=9), User(id=8),User(id=10)] == create_session().query(User).distinct().order_by(desc(User.name)).all() + + def test_joined(self): + """test that orderbys from a joined table get placed into the columns clause when DISTINCT is used""" + + sess = create_session() + q = sess.query(User).join('addresses').distinct().order_by(desc(Address.email_address)) + + assert [User(id=7), User(id=9), User(id=8)] == q.all() + + sess.clear() + + # test that it works on embedded eagerload/LIMIT subquery + q = sess.query(User).join('addresses').distinct().options(eagerload('addresses')).order_by(desc(Address.email_address)).limit(2) + + def go(): + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + ] == q.all() + self.assert_sql_count(testing.db, go, 1) + + +class YieldTest(QueryTest): + def test_basic(self): + import gc + sess = create_session() + q = iter(sess.query(User).yield_per(1).from_statement("select * from users")) + + ret = [] + self.assertEquals(len(sess.identity_map), 0) + ret.append(q.next()) + ret.append(q.next()) + self.assertEquals(len(sess.identity_map), 2) + ret.append(q.next()) + ret.append(q.next()) + self.assertEquals(len(sess.identity_map), 4) + try: + q.next() + assert False + except StopIteration: + pass + class TextTest(QueryTest): def test_fulltext(self): assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).from_statement("select * from users").all() @@ -321,13 +544,13 @@ class TextTest(QueryTest): def test_binds(self): assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all() - - + + class ParentTest(QueryTest): def test_o2m(self): sess = create_session() q = sess.query(User) - + u1 = q.filter_by(name='jack').one() # test auto-lookup of property @@ -346,10 +569,14 @@ class ParentTest(QueryTest): o = sess.query(Order).with_parent(u1).filter(orders.c.id>2).all() assert [Order(description="order 3"), Order(description="order 5")] == o + # test against None for parent? 
this can't be done with the current API since we don't know + # what mapper to use + #assert sess.query(Order).with_parent(None, property='addresses').all() == [Order(description="order 5")] + def test_noparent(self): sess = create_session() q = sess.query(User) - + u1 = q.filter_by(name='jack').one() try: @@ -363,10 +590,32 @@ class ParentTest(QueryTest): i1 = sess.query(Item).filter_by(id=2).one() k = sess.query(Keyword).with_parent(i1).all() assert [Keyword(name='red'), Keyword(name='small'), Keyword(name='square')] == k - - + + class JoinTest(QueryTest): + def test_getjoinable_tables(self): + sess = create_session() + + sel1 = select([users]).alias() + sel2 = select([users], from_obj=users.join(addresses)).alias() + + j1 = sel1.join(users, sel1.c.id==users.c.id) + j2 = j1.join(addresses) + + for from_obj, assert_cond in ( + (users, [users]), + (users.join(addresses), [users, addresses]), + (sel1, [sel1]), + (sel2, [sel2]), + (sel1.join(users, sel1.c.id==users.c.id), [sel1, users]), + (sel2.join(users, sel2.c.id==users.c.id), [sel2, users]), + (j2, [j1, j2, sel1, users, addresses]) + + ): + ret = set(sess.query(User).select_from(from_obj)._get_joinable_tables()) + self.assertEquals(ret, set(assert_cond).union([from_obj]), [x.description for x in ret]) + def test_overlapping_paths(self): for aliased in (True,False): # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack) @@ -376,7 +625,120 @@ class JoinTest(QueryTest): def test_overlapping_paths_outerjoin(self): result = create_session().query(User).outerjoin(['orders', 'items']).filter_by(id=3).outerjoin(['orders','address']).filter_by(id=1).all() assert [User(id=7, name='jack')] == result + + def test_from_joinpoint(self): + sess = create_session() + + for oalias,ialias in [(True, True), (False, False), (True, False), (False, True)]: + self.assertEquals( + sess.query(User).join('orders', aliased=oalias).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description == 'item 4').all(), + [User(name='jack')] + ) + # use middle criterion + self.assertEquals( + sess.query(User).join('orders', aliased=oalias).filter(Order.user_id==9).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description=='item 4').all(), + [] + ) + + orderalias = aliased(Order) + itemalias = aliased(Item) + self.assertEquals( + sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(itemalias.description == 'item 4').all(), + [User(name='jack')] + ) + self.assertEquals( + sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(orderalias.user_id==9).filter(itemalias.description=='item 4').all(), + [] + ) + + def test_orderby_arg_bug(self): + sess = create_session() + + # no arg error + result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all() + + def test_aliased_classes(self): + sess = create_session() + + (user7, user8, user9, user10) = sess.query(User).all() + (address1, address2, address3, address4, address5) = sess.query(Address).all() + expected = [(user7, address1), + (user8, address2), + (user8, address3), + (user8, address4), + (user9, address5), + (user10, None)] + + q = sess.query(User) + AdAlias = aliased(Address) + q = q.add_entity(AdAlias).select_from(outerjoin(User, AdAlias)) + l = q.all() + self.assertEquals(l, expected) + + sess.clear() + + q = sess.query(User).add_entity(AdAlias) + l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all() 
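+        # ed@bettyboop.com is address3, owned by user8, so filtering on the
+        # aliased Address narrows the outer join down to that single pair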
+ self.assertEquals(l, [(user8, address3)]) + + + l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all() + self.assertEquals(l, [(user8, address3)]) + + l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all() + self.assertEquals(l, [(user8, address3)]) + + def test_aliased_classes_m2m(self): + sess = create_session() + + (order1, order2, order3, order4, order5) = sess.query(Order).all() + (item1, item2, item3, item4, item5) = sess.query(Item).all() + expected = [ + (order1, item1), + (order1, item2), + (order1, item3), + (order2, item1), + (order2, item2), + (order2, item3), + (order3, item3), + (order3, item4), + (order3, item5), + (order4, item1), + (order4, item5), + (order5, item5), + ] + + q = sess.query(Order) + q = q.add_entity(Item).select_from(join(Order, Item, 'items')).order_by(Order.id, Item.id) + l = q.all() + self.assertEquals(l, expected) + + IAlias = aliased(Item) + q = sess.query(Order, IAlias).select_from(join(Order, IAlias, 'items')).filter(IAlias.description=='item 3') + l = q.all() + self.assertEquals(l, + [ + (order1, item3), + (order2, item3), + (order3, item3), + ] + ) + + def test_generative_join(self): + # test that alised_ids is copied + sess = create_session() + q = sess.query(User).add_entity(Address) + q1 = q.join('addresses', aliased=True) + q2 = q.join('addresses', aliased=True) + q3 = q2.join('addresses', aliased=True) + q4 = q2.join('addresses', aliased=True, id='someid') + q5 = q2.join('addresses', aliased=True, id='someid') + q6 = q5.join('addresses', aliased=True, id='someid') + assert q1._alias_ids[class_mapper(Address)] != q2._alias_ids[class_mapper(Address)] + assert q2._alias_ids[class_mapper(Address)] != q3._alias_ids[class_mapper(Address)] + assert q4._alias_ids['someid'] != q5._alias_ids['someid'] + def test_reset_joinpoint(self): for aliased in (True, False): # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack) @@ -389,10 +751,10 @@ class JoinTest(QueryTest): def test_overlap_with_aliases(self): oalias = orders.alias('oalias') - result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).all() + result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_(["order 1", "order 2", "order 3"])).join(['orders', 'items']).all() assert [User(id=7, name='jack'), User(id=9, name='fred')] == result - - result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).filter_by(id=4).all() + + result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_(["order 1", "order 2", "order 3"])).join(['orders', 'items']).filter_by(id=4).all() assert [User(id=7, name='jack')] == result def test_aliased(self): @@ -407,6 +769,9 @@ class JoinTest(QueryTest): q = sess.query(User).join('addresses', aliased=True).filter(Address.email_address=='jack@bean.com') assert [User(id=7)] == q.all() + q = sess.query(User).join('addresses', aliased=True).filter(or_(Address.email_address=='jack@bean.com', Address.email_address=='fred@fred.com')) + assert [User(id=7), User(id=9)] == q.all() + # test two aliasized paths, one to 'orders' and the other to 'orders','items'. 
# one row is returned because user 7 has order 3 and also has order 1 which has item 1 # this tests a o2m join and a m2m join. @@ -420,7 +785,7 @@ class JoinTest(QueryTest): q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Order.description=="item 1") assert [] == q.all() assert q.count() == 0 - + q = sess.query(User).join('orders', aliased=True).filter(Order.items.any(Item.description=='item 4')) assert [User(id=7)] == q.all() @@ -492,8 +857,7 @@ class MultiplePathTest(ORMTest): create_session().query(T1).join('t2s_1', aliased=True).filter(t2.c.id==5).reset_joinpoint().join('t2s_2').all() create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2', aliased=True).all() - - + class SynonymTest(QueryTest): keep_mappers = True @@ -536,65 +900,91 @@ class SynonymTest(QueryTest): ): sess = create_session() q = sess.query(User) - + u1 = q.filter_by(**{nameprop:'jack'}).one() o = sess.query(Order).with_parent(u1, property=orderprop).all() assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o - -class InstancesTest(QueryTest): + +class InstancesTest(QueryTest, AssertsCompiledSQL): def test_from_alias(self): query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.id', addresses.c.id]) - q = create_session().query(User) + sess =create_session() + q = sess.query(User) def go(): l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute()) assert fixtures.user_address_result == l - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) + sess.clear() def go(): l = q.options(contains_alias('ulist'), contains_eager('addresses')).from_statement(query).all() assert fixtures.user_address_result == l - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) def test_contains_eager(self): + sess = create_session() + q = sess.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses)) + self.assert_compile(q.statement, "SELECT addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, " + "addresses.email_address AS addresses_email_address, users.id AS users_id, users.name AS users_name "\ + "FROM users LEFT OUTER JOIN addresses ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect()) + + def go(): + assert fixtures.user_address_result == q.all() + self.assert_sql_count(testing.db, go, 1) + sess.clear() + + adalias = addresses.alias() + q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias)) + def go(): + assert fixtures.user_address_result == q.all() + self.assert_sql_count(testing.db, go, 1) + sess.clear() + selectquery = users.outerjoin(addresses).select(users.c.id<10, use_labels=True, order_by=[users.c.id, addresses.c.id]) - q = create_session().query(User) + q = sess.query(User) def go(): l = q.options(contains_eager('addresses')).instances(selectquery.execute()) assert fixtures.user_address_result[0:3] == l - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() def go(): l = q.options(contains_eager('addresses')).from_statement(selectquery).all() assert fixtures.user_address_result[0:3] == l - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) def test_contains_eager_alias(self): adalias = 
addresses.alias('adalias') selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id]) - q = create_session().query(User) + sess = create_session() + q = sess.query(User) def go(): # test using a string alias name l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute()) assert fixtures.user_address_result == l - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) + sess.clear() def go(): # test using the Alias object itself l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute()) assert fixtures.user_address_result == l - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() def decorate(row): d = {} - for c in addresses.columns: + for c in addresses.c: d[c] = row[adalias.corresponding_column(c)] return d @@ -602,13 +992,68 @@ class InstancesTest(QueryTest): # test using a custom 'decorate' function l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute()) assert fixtures.user_address_result == l - self.assert_sql_count(testbase.db, go, 1) + self.assert_sql_count(testing.db, go, 1) + sess.clear() - def test_multi_mappers(self): + oalias = orders.alias('o1') + ialias = items.alias('i1') + query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id).order_by(oalias.c.id).order_by(ialias.c.id) + q = create_session().query(User) + # test using string alias with more than one level deep + def go(): + l = q.options(contains_eager('orders', alias='o1'), contains_eager('orders.items', alias='i1')).instances(query.execute()) + assert fixtures.user_order_result == l + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + + # test using Alias with more than one level deep + def go(): + l = q.options(contains_eager('orders', alias=oalias), contains_eager('orders.items', alias=ialias)).instances(query.execute()) + assert fixtures.user_order_result == l + self.assert_sql_count(testing.db, go, 1) + sess.clear() + + def test_values(self): sess = create_session() - (user7, user8, user9, user10) = sess.query(User).all() - (address1, address2, address3, address4, address5) = sess.query(Address).all() + sel = users.select(User.id.in_([7, 8])).alias() + q = sess.query(User) + q2 = q.select_from(sel).values(User.name) + self.assertEquals(list(q2), [(u'jack',), (u'ed',)]) + + q = sess.query(User) + q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String)) + self.assertEquals(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) + + q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%'))) + self.assertEquals(list(q2), [(True, 1), (False, 3)]) + + q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address) + self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) + + q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address))[1:3].values(User.name, Address.email_address) + self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')]) + + q2 = q.join('addresses', aliased=True).filter(User.name.like('%e%')).values(User.name, Address.email_address) + self.assertEquals(list(q2), 
[(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) + + q2 = q.values(func.count(User.name)) + assert q2.next() == (4,) + + u2 = users.alias() + q2 = q.select_from(sel).filter(u2.c.id>1).order_by([users.c.id, sel.c.id, u2.c.id]).values(users.c.name, sel.c.name, u2.c.name) + self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) + + q2 = q.select_from(sel).filter(users.c.id>1).values(users.c.name, sel.c.name, User.name) + self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'ed', u'ed', u'ed')]) + + def test_multi_mappers(self): + + test_session = create_session() + + (user7, user8, user9, user10) = test_session.query(User).all() + (address1, address2, address3, address4, address5) = test_session.query(Address).all() # note the result is a cartesian product expected = [(user7, address1), @@ -617,27 +1062,36 @@ class InstancesTest(QueryTest): (user8, address4), (user9, address5), (user10, None)] - + + sess = create_session() + selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id]) q = sess.query(User) l = q.instances(selectquery.execute(), Address) assert l == expected + sess.clear() + for aliased in (False, True): q = sess.query(User) + q = q.add_entity(Address).outerjoin('addresses', aliased=aliased) l = q.all() assert l == expected + sess.clear() q = sess.query(User).add_entity(Address) l = q.join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com').all() assert l == [(user8, address3)] - + sess.clear() + q = sess.query(User, Address).join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com') assert q.all() == [(user8, address3)] + sess.clear() q = sess.query(User, Address).join('addresses', aliased=aliased).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com') - assert q.all() == [(user8, address3)] + self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)]) + sess.clear() def test_aliased_multi_mappers(self): sess = create_session() @@ -645,45 +1099,66 @@ class InstancesTest(QueryTest): (user7, user8, user9, user10) = sess.query(User).all() (address1, address2, address3, address4, address5) = sess.query(Address).all() - # note the result is a cartesian product expected = [(user7, address1), (user8, address2), (user8, address3), (user8, address4), (user9, address5), (user10, None)] - + q = sess.query(User) adalias = addresses.alias('adalias') q = q.add_entity(Address, alias=adalias).select_from(users.outerjoin(adalias)) l = q.all() assert l == expected + sess.clear() + q = sess.query(User).add_entity(Address, alias=adalias) l = q.select_from(users.outerjoin(adalias)).filter(adalias.c.email_address=='ed@bettyboop.com').all() assert l == [(user8, address3)] - + def test_multi_columns(self): + sess = create_session() + + expected = [(u, u.name) for u in sess.query(User).all()] + + for add_col in (User.name, users.c.name, User.c.name): + assert sess.query(User).add_column(add_col).all() == expected + sess.clear() + + self.assertRaises(exceptions.InvalidRequestError, sess.query(User).add_column, object()) + + def test_ambiguous_column(self): + sess = create_session() + + q = sess.query(User).join('addresses', aliased=True).join('addresses', aliased=True).add_column(Address.id) + 
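# Address.id can match either of the two anonymous 'addresses' aliases
+        # created above, so iterating the query is expected to raise
+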
self.assertRaises(exceptions.InvalidRequestError, iter, q) + + def test_multi_columns_2(self): """test aliased/nonalised joins with the usage of add_column()""" sess = create_session() + (user7, user8, user9, user10) = sess.query(User).all() expected = [(user7, 1), (user8, 3), (user9, 1), (user10, 0) ] - + for aliased in (False, True): q = sess.query(User) q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(Address.id).label('count')) l = q.all() assert l == expected + sess.clear() s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id) q = sess.query(User) l = q.add_column("count").from_statement(s).all() assert l == expected + def test_two_columns(self): sess = create_session() (user7, user8, user9, user10) = sess.query(User).all() @@ -693,26 +1168,202 @@ class InstancesTest(QueryTest): (user9, 1, "Name:fred"), (user10, 0, "Name:chuck")] + q = create_session().query(User).add_column(func.count(addresses.c.id))\ + .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=True)\ + .group_by([c for c in users.c]).order_by(users.c.id) + + assert q.all() == expected + # test with a straight statement s = select([users, func.count(addresses.c.id).label('count'), ("Name:" + users.c.name).label('concat')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.id]) q = create_session().query(User) l = q.add_column("count").add_column("concat").from_statement(s).all() assert l == expected - + + sess.clear() + # test with select_from() q = create_session().query(User).add_column(func.count(addresses.c.id))\ .add_column(("Name:" + users.c.name)).select_from(users.outerjoin(addresses))\ .group_by([c for c in users.c]).order_by(users.c.id) - + assert q.all() == expected + sess.clear() # test with outerjoin() both aliased and non for aliased in (False, True): q = create_session().query(User).add_column(func.count(addresses.c.id))\ .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=aliased)\ .group_by([c for c in users.c]).order_by(users.c.id) - + assert q.all() == expected + sess.clear() + + +class SelectFromTest(QueryTest): + keep_mappers = False + + def setup_mappers(self): + pass + + def test_replace_with_select(self): + mapper(User, users, properties = { + 'addresses':relation(Address) + }) + mapper(Address, addresses) + + sel = users.select(users.c.id.in_([7, 8])).alias() + sess = create_session() + + self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)]) + + self.assertEquals(sess.query(User).select_from(sel).filter(User.c.id==8).all(), [User(id=8)]) + + self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [ + User(name='jack',id=7), User(name='ed',id=8) + ]) + + self.assertEquals(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [ + User(name='ed',id=8), User(name='jack',id=7) + ]) + + self.assertEquals(sess.query(User).select_from(sel).options(eagerload('addresses')).first(), + User(name='jack', addresses=[Address(id=1)]) + ) + + def test_join_mapper_order_by(self): + mapper(User, users, order_by=users.c.id) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + self.assertEquals(sess.query(User).select_from(sel).all(), + [ + User(name='jack',id=7), User(name='ed',id=8) + ] + ) + + def test_join_no_order_by(self): + mapper(User, users) + + sel = 
users.select(users.c.id.in_([7, 8])) + sess = create_session() + + self.assertEquals(sess.query(User).select_from(sel).all(), + [ + User(name='jack',id=7), User(name='ed',id=8) + ] + ) + + def test_join(self): + mapper(User, users, properties = { + 'addresses':relation(Address) + }) + mapper(Address, addresses) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + self.assertEquals(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(), + [ + (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@bettyboop.com',id=3)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@lala.com',id=4)) + ] + ) + + self.assertEquals(sess.query(User).select_from(sel).join('addresses', aliased=True).add_entity(Address).order_by(User.id).order_by(Address.id).all(), + [ + (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@bettyboop.com',id=3)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@lala.com',id=4)) + ] + ) + + + def test_more_joins(self): + mapper(User, users, properties={ + 'orders':relation(Order, backref='user'), # o2m, m2o + }) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m + }) + mapper(Item, items, properties={ + 'keywords':relation(Keyword, secondary=item_keywords, order_by=keywords.c.id) #m2m + }) + mapper(Keyword, keywords) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + self.assertEquals(sess.query(User).select_from(sel).join(['orders', 'items', 'keywords']).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + User(name=u'jack',id=7) + ]) + + self.assertEquals(sess.query(User).select_from(sel).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + User(name=u'jack',id=7) + ]) + + def go(): + self.assertEquals(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + User(name=u'jack',orders=[ + Order(description=u'order 1',items=[ + Item(description=u'item 1',keywords=[Keyword(name=u'red'), Keyword(name=u'big'), Keyword(name=u'round')]), + Item(description=u'item 2',keywords=[Keyword(name=u'red',id=2), Keyword(name=u'small',id=5), Keyword(name=u'square')]), + Item(description=u'item 3',keywords=[Keyword(name=u'green',id=3), Keyword(name=u'big',id=4), Keyword(name=u'round',id=6)]) + ]), + Order(description=u'order 3',items=[ + Item(description=u'item 3',keywords=[Keyword(name=u'green',id=3), Keyword(name=u'big',id=4), Keyword(name=u'round',id=6)]), + Item(description=u'item 4',keywords=[],id=4), + Item(description=u'item 5',keywords=[],id=5) + ]), + Order(description=u'order 5',items=[Item(description=u'item 5',keywords=[])])]) + ]) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + sel2 = orders.select(orders.c.id.in_([1,2,3])) + self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').all(), [ + Order(description=u'order 1',id=1), + Order(description=u'order 2',id=2), + ]) + 
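# the same join expressed with aliased=True should return identical rows
+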
self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').all(), [ + Order(description=u'order 1',id=1), + Order(description=u'order 2',id=2), + ]) + + + def test_replace_with_eager(self): + mapper(User, users, properties = { + 'addresses':relation(Address) + }) + mapper(Address, addresses) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + def go(): + self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).all(), + [ + User(id=7, addresses=[Address(id=1)]), + User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + sess.clear() + + def go(): + self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.c.id==8).all(), + [User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])] + ) + self.assert_sql_count(testing.db, go, 1) + sess.clear() + + def go(): + self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])) + self.assert_sql_count(testing.db, go, 1) class CustomJoinTest(QueryTest): keep_mappers = False @@ -736,21 +1387,26 @@ class CustomJoinTest(QueryTest): assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all() -class SelfReferentialJoinTest(ORMTest): +class SelfReferentialTest(ORMTest): + keep_mappers = True + keep_data = True + def define_tables(self, metadata): global nodes nodes = Table('nodes', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('nodes.id')), Column('data', String(30))) - - def test_join(self): + + def insert_data(self): + global Node + class Node(Base): def append(self, node): self.children.append(node) mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=True, join_depth=3, + 'children':relation(Node, lazy=True, join_depth=3, backref=backref('parent', remote_side=[nodes.c.id]) ) }) @@ -764,21 +1420,146 @@ class SelfReferentialJoinTest(ORMTest): n1.children[1].append(Node(data='n123')) sess.save(n1) sess.flush() - sess.clear() + sess.close() - # TODO: the aliasing of the join in query._join_to has to limit the aliasing - # among local_side / remote_side (add local_side as an attribute on PropertyLoader) - # also implement this idea in EagerLoader + def test_join(self): + sess = create_session() + node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first() assert node.data=='n12' node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first() assert node.data=='n1' - + node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\ join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first() assert node.data == 'n122' + + def test_explicit_join(self): + sess = create_session() + + n1 = aliased(Node) + n2 = aliased(Node) + + node = sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data=='n122').first() + assert node.data=='n12' + + node = sess.query(Node).select_from(join(Node, n1, 'children').join(n2, 'children')).\ + filter(n2.data=='n122').first() + assert node.data=='n1' + + # mix explicit and named onclauses + node = sess.query(Node).select_from(join(Node, n1, 
Node.id==n1.parent_id).join(n2, 'children')).\ + filter(n2.data=='n122').first() + assert node.data=='n1' + + node = sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\ + filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first() + assert node.data == 'n122' + + self.assertEquals( + list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\ + filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)), + [('n122', 'n12', 'n1')]) + + def test_any(self): + sess = create_session() + + self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) + self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) + self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) + + def test_has(self): + sess = create_session() + + self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) + self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) + + def test_contains(self): + sess = create_session() + + n122 = sess.query(Node).filter(Node.data=='n122').one() + self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')]) + + n13 = sess.query(Node).filter(Node.data=='n13').one() + self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')]) + + def test_eq_ne(self): + sess = create_session() + + n12 = sess.query(Node).filter(Node.data=='n12').one() + self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + + self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) + +class SelfReferentialM2MTest(ORMTest): + keep_mappers = True + keep_data = True + + def define_tables(self, metadata): + global nodes, node_to_nodes + nodes = Table('nodes', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('data', String(30))) + + node_to_nodes =Table('node_to_nodes', metadata, + Column('left_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), + Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), + ) + + def insert_data(self): + global Node + + class Node(Base): + pass + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=True, secondary=node_to_nodes, + primaryjoin=nodes.c.id==node_to_nodes.c.left_node_id, + secondaryjoin=nodes.c.id==node_to_nodes.c.right_node_id, + ) + }) + sess = create_session() + n1 = Node(data='n1') + n2 = Node(data='n2') + n3 = Node(data='n3') + n4 = Node(data='n4') + n5 = Node(data='n5') + n6 = Node(data='n6') + n7 = Node(data='n7') + + n1.children = [n2, n3, n4] + n2.children = [n3, n6, n7] + n3.children = [n5, n4] + + sess.save(n1) + sess.save(n2) + sess.save(n3) + sess.save(n4) + sess.flush() + sess.close() + + def test_any(self): + sess = create_session() + self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')]) + + def test_contains(self): + sess = create_session() + n4 = sess.query(Node).filter_by(data='n4').one() + + 
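# n4 was made a child of both n1 and n3 in insert_data()
+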
self.assertEquals(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')]) + self.assertEquals(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')]) + + def test_explicit_join(self): + sess = create_session() + + n1 = aliased(Node) + self.assertEquals( + sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data.in_(['n3', 'n7'])).all(), + [Node(data='n1'), Node(data='n2')] + ) + class ExternalColumnsTest(QueryTest): keep_mappers = False @@ -786,40 +1567,32 @@ class ExternalColumnsTest(QueryTest): pass def test_external_columns_bad(self): - """test that SA catches some common mis-configurations of external columns.""" - f = (users.c.id * 2) - try: - mapper(User, users, properties={ - 'concat': f, - }) - class_mapper(User) - except exceptions.ArgumentError, e: - assert str(e) == "Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(f) - else: - raise 'expected ArgumentError' + + self.assertRaisesMessage(exceptions.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={ + 'concat': (users.c.id * 2), + }) clear_mappers() - try: - mapper(User, users, properties={ - 'concat': column_property(users.c.id * 2), - }) - except exceptions.ArgumentError, e: - assert str(e) == 'ColumnProperties must be named for the mapper to work with them. Try .label() to fix this' - else: - raise 'expected ArgumentError' + + self.assertRaisesMessage(exceptions.ArgumentError, "must be given a ColumnElement as its argument.", column_property, + select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users) + ) def test_external_columns_good(self): """test querying mappings that reference external columns or selectables.""" + mapper(User, users, properties={ - 'concat': column_property((users.c.id * 2).label('concat')), - 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users).label('count')) + 'concat': column_property((users.c.id * 2)), + 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users).as_scalar()) }) mapper(Address, addresses, properties={ 'user':relation(User, lazy=True) - }) + }) sess = create_session() - l = sess.query(User).select() + + + l = sess.query(User).all() assert [ User(id=7, concat=14, count=1), User(id=8, concat=16, count=3), @@ -834,23 +1607,47 @@ class ExternalColumnsTest(QueryTest): Address(id=4, user=User(id=8, concat=16, count=3)), Address(id=5, user=User(id=9, concat=18, count=1)) ] - - assert address_result == sess.query(Address).all() - + + self.assertEquals(sess.query(Address).all(), address_result) + # run the eager version twice to test caching of aliased clauses for x in range(2): sess.clear() def go(): - assert address_result == sess.query(Address).options(eagerload('user')).all() - self.assert_sql_count(testbase.db, go, 1) - + self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result) + self.assert_sql_count(testing.db, go, 1) + tuple_address_result = [(address, address.user) for address in address_result] - tuple_address_result == sess.query(Address).join('user').add_entity(User).all() + q =sess.query(Address).join('user', aliased=True, id='ualias').join('user', aliased=True).add_column(User.concat) + 
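# User.concat could refer to either aliased join to 'user', so
+        # q.all() is expected to fail with the "Ambiguous" error
+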
self.assertRaisesMessage(exceptions.InvalidRequestError, "Ambiguous", q.all) + + self.assertEquals(sess.query(Address).join('user', aliased=True, id='ualias').add_entity(User, id='ualias').all(), tuple_address_result) + + self.assertEquals(sess.query(Address).join('user', aliased=True, id='ualias').join('user', aliased=True).\ + add_column(User.concat, id='ualias').add_column(User.count, id='ualias').all(), + [ + (Address(id=1), 14, 1), + (Address(id=2), 16, 3), + (Address(id=3), 16, 3), + (Address(id=4), 16, 3), + (Address(id=5), 18, 1) + ] + ) - assert tuple_address_result == sess.query(Address).join('user', aliased=True, id='ualias').add_entity(User, id='ualias').all() + self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), + [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] + ) -if __name__ == '__main__': - testbase.main() + self.assertEquals(list(sess.query(Address).join('user', aliased=True).values(Address.id, User.id, User.concat, User.count)), + [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] + ) + + ua = aliased(User) + self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), + [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] + ) +if __name__ == '__main__': + testenv.main() diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 9fca22b244..40773f8359 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -1,46 +1,54 @@ -import testbase +import testenv; testenv.configure_for_tests() import datetime from sqlalchemy import * +from sqlalchemy import exceptions, types from sqlalchemy.orm import * from sqlalchemy.orm import collections from sqlalchemy.orm.collections import collection from testlib import * +from testlib import fixtures + +class RelationTest(TestBase): + """An extended topological sort test + + This is essentially an extension of the "dependency.py" topological sort + test. In this test, a table is dependent on two other tables that are + otherwise unrelated to each other. The dependency sort must insure that + this childmost table is below both parent tables in the outcome (a bug + existed where this was not always the case). + + While the straight topological sort tests should expose this, since the + sorting can be different due to subtle differences in program execution, + this test case was exposing the bug whereas the simpler tests were not. + """ -class RelationTest(PersistTest): - """this is essentially an extension of the "dependency.py" topological sort test. - in this test, a table is dependent on two other tables that are otherwise unrelated to each other. - the dependency sort must insure that this childmost table is below both parent tables in the outcome - (a bug existed where this was not always the case). 
- while the straight topological sort tests should expose this, since the sorting can be different due - to subtle differences in program execution, this test case was exposing the bug whereas the simpler tests - were not.""" def setUpAll(self): global metadata, tbl_a, tbl_b, tbl_c, tbl_d metadata = MetaData() tbl_a = Table("tbl_a", metadata, Column("id", Integer, primary_key=True), - Column("name", String), + Column("name", String(128)), ) tbl_b = Table("tbl_b", metadata, Column("id", Integer, primary_key=True), - Column("name", String), + Column("name", String(128)), ) tbl_c = Table("tbl_c", metadata, Column("id", Integer, primary_key=True), Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False), - Column("name", String), + Column("name", String(128)), ) tbl_d = Table("tbl_d", metadata, Column("id", Integer, primary_key=True), Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False), Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), - Column("name", String), + Column("name", String(128)), ) def setUp(self): global session - session = create_session(bind=testbase.db) - conn = testbase.db.connect() + session = create_session(bind=testing.db) + conn = testing.db.connect() conn.create(tbl_a) conn.create(tbl_b) conn.create(tbl_c) @@ -57,11 +65,11 @@ class RelationTest(PersistTest): D.mapper = mapper(D, tbl_d) C.mapper = mapper(C, tbl_c, properties=dict( - d_rows=relation(D, private=True, backref="c_row"), + d_rows=relation(D, cascade="all, delete-orphan", backref="c_row"), )) B.mapper = mapper(B, tbl_b) A.mapper = mapper(A, tbl_a, properties=dict( - c_rows=relation(C, private=True, backref="a_row"), + c_rows=relation(C, cascade="all, delete-orphan", backref="a_row"), )) D.mapper.add_property("b_row", relation(B)) @@ -76,35 +84,39 @@ class RelationTest(PersistTest): d3 = D(); d3.name = "d3"; d3.b_row = b; d3.c_row = c session.save_or_update(a) session.save_or_update(b) - + def tearDown(self): - conn = testbase.db.connect() + conn = testing.db.connect() conn.drop(tbl_d) conn.drop(tbl_c) conn.drop(tbl_b) conn.drop(tbl_a) def tearDownAll(self): - metadata.drop_all(testbase.db) - + metadata.drop_all(testing.db) + def testDeleteRootTable(self): session.flush() session.delete(a) # works as expected session.flush() - + def testDeleteMiddleTable(self): session.flush() session.delete(c) # fails session.flush() - -class RelationTest2(PersistTest): - """this test tests a relationship on a column that is included in multiple foreign keys, - as well as a self-referential relationship on a composite key where one column in the foreign key - is 'joined to itself'.""" + +class RelationTest2(TestBase): + """Tests a relationship on a column included in multiple foreign keys. + + This test tests a relationship on a column that is included in multiple + foreign keys, as well as a self-referential relationship on a composite + key where one column in the foreign key is 'joined to itself'. 
+ """ + def setUpAll(self): global metadata, company_tbl, employee_tbl - metadata = MetaData(testbase.db) - + metadata = MetaData(testing.db) + company_tbl = Table('company', metadata, Column('company_id', Integer, primary_key=True), Column('name', Unicode(30))) @@ -118,12 +130,13 @@ class RelationTest2(PersistTest): ForeignKeyConstraint(['company_id', 'reports_to_id'], ['employee.company_id', 'employee.emp_id'])) metadata.create_all() - + def tearDownAll(self): - metadata.drop_all() + metadata.drop_all() - def testexplicit(self): + def test_explicit(self): """test with mappers that have fairly explicit join conditions""" + class Company(object): pass class Employee(object): @@ -132,7 +145,7 @@ class RelationTest2(PersistTest): self.company = company self.emp_id = emp_id self.reports_to = reports_to - + mapper(Company, company_tbl) mapper(Employee, employee_tbl, properties= { 'company':relation(Company, primaryjoin=employee_tbl.c.company_id==company_tbl.c.company_id, backref='employees'), @@ -140,8 +153,9 @@ class RelationTest2(PersistTest): and_( employee_tbl.c.emp_id==employee_tbl.c.reports_to_id, employee_tbl.c.company_id==employee_tbl.c.company_id - ), - foreignkey=[employee_tbl.c.company_id, employee_tbl.c.emp_id], + ), + remote_side=[employee_tbl.c.emp_id, employee_tbl.c.company_id], + foreign_keys=[employee_tbl.c.reports_to_id], backref='employees') }) @@ -149,13 +163,13 @@ class RelationTest2(PersistTest): c1 = Company() c2 = Company() - e1 = Employee('emp1', c1, 1) - e2 = Employee('emp2', c1, 2, e1) - e3 = Employee('emp3', c1, 3, e1) - e4 = Employee('emp4', c1, 4, e3) - e5 = Employee('emp5', c2, 1) - e6 = Employee('emp6', c2, 2, e5) - e7 = Employee('emp7', c2, 3, e5) + e1 = Employee(u'emp1', c1, 1) + e2 = Employee(u'emp2', c1, 2, e1) + e3 = Employee(u'emp3', c1, 3, e1) + e4 = Employee(u'emp4', c1, 4, e3) + e5 = Employee(u'emp5', c2, 1) + e6 = Employee(u'emp6', c2, 2, e5) + e7 = Employee(u'emp7', c2, 3, e5) [sess.save(x) for x in [c1,c2]] sess.flush() @@ -163,14 +177,14 @@ class RelationTest2(PersistTest): test_c1 = sess.query(Company).get(c1.company_id) test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) - assert test_e1.name == 'emp1' + assert test_e1.name == 'emp1', test_e1.name test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) - assert test_e5.name == 'emp5' + assert test_e5.name == 'emp5', test_e5.name assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' - def testimplicit(self): + def test_implicit(self): """test with mappers that have the most minimal arguments""" class Company(object): pass @@ -184,8 +198,9 @@ class RelationTest2(PersistTest): mapper(Company, company_tbl) mapper(Employee, employee_tbl, properties= { 'company':relation(Company, backref='employees'), - 'reports_to':relation(Employee, - foreignkey=[employee_tbl.c.company_id, employee_tbl.c.emp_id], + 'reports_to':relation(Employee, + remote_side=[employee_tbl.c.emp_id, employee_tbl.c.company_id], + foreign_keys=[employee_tbl.c.reports_to_id], backref='employees') }) @@ -193,13 +208,13 @@ class RelationTest2(PersistTest): c1 = Company() c2 = Company() - e1 = Employee('emp1', c1, 1) - e2 = Employee('emp2', c1, 2, e1) - e3 = Employee('emp3', c1, 3, e1) - e4 = Employee('emp4', c1, 4, e3) - e5 = Employee('emp5', c2, 1) - e6 = Employee('emp6', c2, 2, e5) - e7 = Employee('emp7', c2, 3, e5) + e1 = Employee(u'emp1', c1, 1) + e2 = 
Employee(u'emp2', c1, 2, e1) + e3 = Employee(u'emp3', c1, 3, e1) + e4 = Employee(u'emp4', c1, 4, e3) + e5 = Employee(u'emp5', c2, 1) + e6 = Employee(u'emp6', c2, 2, e5) + e7 = Employee(u'emp7', c2, 3, e5) [sess.save(x) for x in [c1,c2]] sess.flush() @@ -207,18 +222,18 @@ class RelationTest2(PersistTest): test_c1 = sess.query(Company).get(c1.company_id) test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) - assert test_e1.name == 'emp1' + assert test_e1.name == 'emp1', test_e1.name test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) - assert test_e5.name == 'emp5' + assert test_e5.name == 'emp5', test_e5.name assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' - -class RelationTest3(PersistTest): + +class RelationTest3(TestBase): def setUpAll(self): global jobs, pageversions, pages, metadata, Job, Page, PageVersion, PageComment import datetime - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) jobs = Table("jobs", metadata, Column("jobno", Unicode(15), primary_key=True), Column("created", DateTime, nullable=False, default=datetime.datetime.now), @@ -242,8 +257,8 @@ class RelationTest3(PersistTest): pagecomments = Table("pagecomments", metadata, Column("jobno", Unicode(15), primary_key=True), Column("pagename", Unicode(30), primary_key=True), - Column("comment_id", Integer, primary_key=True), - Column("content", Unicode), + Column("comment_id", Integer, primary_key=True, autoincrement=False), + Column("content", UnicodeText), ForeignKeyConstraint(["jobno", "pagename"], ["pages.jobno", "pages.pagename"]) ) @@ -284,7 +299,7 @@ class RelationTest3(PersistTest): mapper(Page, pages, properties={ 'job': relation(Job, backref=backref('pages', cascade="all, delete-orphan", order_by=pages.c.pagename)), 'currentversion': relation(PageVersion, - foreignkey=pages.c.current_version, + foreign_keys=[pages.c.current_version], primaryjoin=and_(pages.c.jobno==pageversions.c.jobno, pages.c.pagename==pageversions.c.pagename, pages.c.current_version==pageversions.c.version), @@ -309,19 +324,19 @@ class RelationTest3(PersistTest): def tearDownAll(self): clear_mappers() - metadata.drop_all() + metadata.drop_all() def testbasic(self): """test the combination of complicated join conditions with post_update""" - j1 = Job('somejob') - j1.create_page('page1') - j1.create_page('page2') - j1.create_page('page3') + j1 = Job(u'somejob') + j1.create_page(u'page1') + j1.create_page(u'page2') + j1.create_page(u'page3') - j2 = Job('somejob2') - j2.create_page('page1') - j2.create_page('page2') - j2.create_page('page3') + j2 = Job(u'somejob2') + j2.create_page(u'page1') + j2.create_page(u'page2') + j2.create_page(u'page3') j2.pages[0].add_version() j2.pages[0].add_version() @@ -333,17 +348,18 @@ class RelationTest3(PersistTest): s.save(j1) s.save(j2) + s.flush() s.clear() - j = s.query(Job).get_by(jobno='somejob') + j = s.query(Job).filter_by(jobno=u'somejob').one() oldp = list(j.pages) j.pages = [] s.flush() s.clear() - j = s.query(Job).get_by(jobno='somejob2') + j = s.query(Job).filter_by(jobno=u'somejob2').one() j.pages[1].current_version = 12 s.delete(j) s.flush() @@ -352,13 +368,13 @@ class RelationTest4(ORMTest): """test syncrules on foreign keys that are also primary""" def define_tables(self, metadata): global tableA, tableB - tableA = Table("A", metadata, + tableA = Table("A", metadata, 
Column("id",Integer,primary_key=True), Column("foo",Integer,), - ) + test_needs_fk=True) tableB = Table("B",metadata, Column("id",Integer,ForeignKey("A.id"),primary_key=True), - ) + test_needs_fk=True) def test_no_delete_PK_AtoB(self): """test that A cant be deleted without B because B would have no PK value""" class A(object):pass @@ -372,7 +388,7 @@ class RelationTest4(ORMTest): sess = create_session() sess.save(a1) sess.flush() - + sess.delete(a1) try: sess.flush() @@ -400,8 +416,30 @@ class RelationTest4(ORMTest): except exceptions.AssertionError, e: assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ") + @testing.fails_on_everything_except('sqlite', 'mysql') + def test_nullPKsOK_BtoA(self): + # postgres cant handle a nullable PK column...? + tableC = Table('tablec', tableA.metadata, + Column('id', Integer, primary_key=True), + Column('a_id', Integer, ForeignKey('A.id'), primary_key=True, autoincrement=False, nullable=True)) + tableC.create() + + class A(object):pass + class C(object):pass + mapper(C, tableC, properties={ + 'a':relation(A, cascade="save-update") + }, allow_null_pks=True) + mapper(A, tableA) + c1 = C() + c1.id = 5 + c1.a = None + sess = create_session() + sess.save(c1) + # test that no error is raised. + sess.flush() + def test_delete_cascade_BtoA(self): - """test that the 'blank the PK' error doesnt get raised when the child is to be deleted as part of a + """test that the 'blank the PK' error doesnt get raised when the child is to be deleted as part of a cascade""" class A(object):pass class B(object):pass @@ -426,9 +464,9 @@ class RelationTest4(ORMTest): assert b1 not in sess sess.clear() clear_mappers() - + def test_delete_cascade_AtoB(self): - """test that the 'blank the PK' error doesnt get raised when the child is to be deleted as part of a + """test that the 'blank the PK' error doesnt get raised when the child is to be deleted as part of a cascade""" class A(object):pass class B(object):pass @@ -446,14 +484,14 @@ class RelationTest4(ORMTest): sess = create_session() sess.save(a1) sess.flush() - + sess.delete(a1) sess.flush() assert a1 not in sess assert b1 not in sess sess.clear() clear_mappers() - + def test_delete_manual_AtoB(self): class A(object):pass class B(object):pass @@ -468,7 +506,7 @@ class RelationTest4(ORMTest): sess.save(a1) sess.save(b1) sess.flush() - + sess.delete(a1) sess.delete(b1) sess.flush() @@ -497,23 +535,24 @@ class RelationTest4(ORMTest): assert b1 not in sess class RelationTest5(ORMTest): - """test a map to a select that relates to a map to the table""" + """Test a map to a select that relates to a map to the table.""" + def define_tables(self, metadata): global items items = Table('items', metadata, Column('item_policy_num', String(10), primary_key=True, key='policyNum'), Column('item_policy_eff_date', Date, primary_key=True, key='policyEffDate'), Column('item_type', String(20), primary_key=True, key='type'), - Column('item_id', Integer, primary_key=True, key='id'), + Column('item_id', Integer, primary_key=True, key='id', autoincrement=False), ) def test_basic(self): class Container(object):pass class LineItem(object):pass - + container_select = select( [items.c.policyNum, items.c.policyEffDate, items.c.type], - distinct=True, + distinct=True, ).alias('container_select') mapper(LineItem, items) @@ -545,29 +584,29 @@ class RelationTest5(ORMTest): session.save(li) session.flush() session.clear() - newcon = session.query(Container).selectfirst() + newcon = session.query(Container).first() 
assert con.policyNum == newcon.policyNum assert len(newcon.lineItems) == 10 for old, new in zip(con.lineItems, newcon.lineItems): assert old.id == new.id - - + + class TypeMatchTest(ORMTest): """test errors raised when trying to add items whose type is not handled by a relation""" def define_tables(self, metadata): global a, b, c, d - a = Table("a", metadata, + a = Table("a", metadata, Column('aid', Integer, primary_key=True), Column('data', String(30))) - b = Table("b", metadata, + b = Table("b", metadata, Column('bid', Integer, primary_key=True), Column("a_id", Integer, ForeignKey("a.aid")), Column('data', String(30))) - c = Table("c", metadata, + c = Table("c", metadata, Column('cid', Integer, primary_key=True), Column("b_id", Integer, ForeignKey("b.bid")), Column('data', String(30))) - d = Table("d", metadata, + d = Table("d", metadata, Column('did', Integer, primary_key=True), Column("a_id", Integer, ForeignKey("a.aid")), Column('data', String(30))) @@ -578,7 +617,7 @@ class TypeMatchTest(ORMTest): mapper(A, a, properties={'bs':relation(B)}) mapper(B, b) mapper(C, c) - + a1 = A() b1 = B() c1 = C() @@ -597,7 +636,7 @@ class TypeMatchTest(ORMTest): mapper(A, a, properties={'bs':relation(B, cascade="none")}) mapper(B, b) mapper(C, c) - + a1 = A() b1 = B() c1 = C() @@ -619,7 +658,7 @@ class TypeMatchTest(ORMTest): mapper(A, a, properties={'bs':relation(B, cascade="none")}) mapper(B, b) mapper(C, c, inherits=B) - + a1 = A() b1 = B() c1 = C() @@ -634,7 +673,7 @@ class TypeMatchTest(ORMTest): assert False except exceptions.FlushError, err: assert str(err).startswith("Attempting to flush an item of type %s on collection 'A.bs (B)', which is handled by mapper 'Mapper|B|b' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ?" 
% C) - + def test_m2o_nopoly_onflush(self): class A(object):pass class B(A):pass @@ -673,18 +712,18 @@ class TypeMatchTest(ORMTest): class TypedAssociationTable(ORMTest): def define_tables(self, metadata): global t1, t2, t3 - - class MySpecialType(TypeDecorator): + + class MySpecialType(types.TypeDecorator): impl = String def convert_bind_param(self, value, dialect): return "lala" + value def convert_result_value(self, value, dialect): return value[4:] - - t1 = Table('t1', metadata, + + t1 = Table('t1', metadata, Column('col1', MySpecialType(30), primary_key=True), Column('col2', String(30))) - t2 = Table('t2', metadata, + t2 = Table('t2', metadata, Column('col1', MySpecialType(30), primary_key=True), Column('col2', String(30))) t3 = Table('t3', metadata, @@ -693,7 +732,7 @@ class TypedAssociationTable(ORMTest): ) def testm2m(self): """test many-to-many tables with special types for candidate keys""" - + class T1(object):pass class T2(object):pass mapper(T2, t2) @@ -713,293 +752,14 @@ class TypedAssociationTable(ORMTest): sess.flush() assert t3.count().scalar() == 2 - - a.t2s.remove(c) - sess.flush() - - assert t3.count().scalar() == 1 - -# TODO: move these tests to either attributes.py test or its own module -class CustomCollectionsTest(ORMTest): - def define_tables(self, metadata): - global sometable, someothertable - sometable = Table('sometable', metadata, - Column('col1',Integer, primary_key=True), - Column('data', String(30))) - someothertable = Table('someothertable', metadata, - Column('col1', Integer, primary_key=True), - Column('scol1', Integer, ForeignKey(sometable.c.col1)), - Column('data', String(20)) - ) - def testbasic(self): - class MyList(list): - pass - class Foo(object): - pass - class Bar(object): - pass - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, collection_class=MyList) - }) - mapper(Bar, someothertable) - f = Foo() - assert isinstance(f.bars, MyList) - def testlazyload(self): - """test that a 'set' can be used as a collection and can lazyload.""" - class Foo(object): - pass - class Bar(object): - pass - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, collection_class=set) - }) - mapper(Bar, someothertable) - f = Foo() - f.bars.add(Bar()) - f.bars.add(Bar()) - sess = create_session() - sess.save(f) - sess.flush() - sess.clear() - f = sess.query(Foo).get(f.col1) - assert len(list(f.bars)) == 2 - f.bars.clear() - - def testdict(self): - """test that a 'dict' can be used as a collection and can lazyload.""" - - class Foo(object): - pass - class Bar(object): - pass - class AppenderDict(dict): - @collection.appender - def set(self, item): - self[id(item)] = item - @collection.remover - def remove(self, item): - if id(item) in self: - del self[id(item)] - - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, collection_class=AppenderDict) - }) - mapper(Bar, someothertable) - f = Foo() - f.bars.set(Bar()) - f.bars.set(Bar()) - sess = create_session() - sess.save(f) - sess.flush() - sess.clear() - f = sess.query(Foo).get(f.col1) - assert len(list(f.bars)) == 2 - f.bars.clear() - - def testdictwrapper(self): - """test that the supplied 'dict' wrapper can be used as a collection and can lazyload.""" - - class Foo(object): - pass - class Bar(object): - def __init__(self, data): self.data = data - - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, - collection_class=collections.column_mapped_collection(someothertable.c.data)) - }) - mapper(Bar, someothertable) - - f = Foo() - col = collections.collection_adapter(f.bars) - 
col.append_with_event(Bar('a')) - col.append_with_event(Bar('b')) - sess = create_session() - sess.save(f) - sess.flush() - sess.clear() - f = sess.query(Foo).get(f.col1) - assert len(list(f.bars)) == 2 - - existing = set([id(b) for b in f.bars.values()]) - col = collections.collection_adapter(f.bars) - col.append_with_event(Bar('b')) - f.bars['a'] = Bar('a') + a.t2s.remove(c) sess.flush() - sess.clear() - f = sess.query(Foo).get(f.col1) - assert len(list(f.bars)) == 2 - - replaced = set([id(b) for b in f.bars.values()]) - self.assert_(existing != replaced) - - def testlist(self): - class Parent(object): - pass - class Child(object): - pass - - mapper(Parent, sometable, properties={ - 'children':relation(Child, collection_class=list) - }) - mapper(Child, someothertable) - - control = list() - p = Parent() - - o = Child() - control.append(o) - p.children.append(o) - assert control == p.children - assert control == list(p.children) - - o = [Child(), Child(), Child(), Child()] - control.extend(o) - p.children.extend(o) - assert control == p.children - assert control == list(p.children) - - assert control[0] == p.children[0] - assert control[-1] == p.children[-1] - assert control[1:3] == p.children[1:3] - - del control[1] - del p.children[1] - assert control == p.children - assert control == list(p.children) - - o = [Child()] - control[1:3] = o - p.children[1:3] = o - assert control == p.children - assert control == list(p.children) - - o = [Child(), Child(), Child(), Child()] - control[1:3] = o - p.children[1:3] = o - assert control == p.children - assert control == list(p.children) - - o = [Child(), Child(), Child(), Child()] - control[-1:-2] = o - p.children[-1:-2] = o - assert control == p.children - assert control == list(p.children) - - o = [Child(), Child(), Child(), Child()] - control[4:] = o - p.children[4:] = o - assert control == p.children - assert control == list(p.children) - - o = Child() - control.insert(0, o) - p.children.insert(0, o) - assert control == p.children - assert control == list(p.children) - - o = Child() - control.insert(3, o) - p.children.insert(3, o) - assert control == p.children - assert control == list(p.children) - - o = Child() - control.insert(999, o) - p.children.insert(999, o) - assert control == p.children - assert control == list(p.children) - - del control[0:1] - del p.children[0:1] - assert control == p.children - assert control == list(p.children) - - del control[1:1] - del p.children[1:1] - assert control == p.children - assert control == list(p.children) - - del control[1:3] - del p.children[1:3] - assert control == p.children - assert control == list(p.children) - - del control[7:] - del p.children[7:] - assert control == p.children - assert control == list(p.children) - - assert control.pop() == p.children.pop() - assert control == p.children - assert control == list(p.children) - - assert control.pop(0) == p.children.pop(0) - assert control == p.children - assert control == list(p.children) - - assert control.pop(2) == p.children.pop(2) - assert control == p.children - assert control == list(p.children) - - o = Child() - control.insert(2, o) - p.children.insert(2, o) - assert control == p.children - assert control == list(p.children) - - control.remove(o) - p.children.remove(o) - assert control == p.children - assert control == list(p.children) - - def testobj(self): - class Parent(object): - pass - class Child(object): - pass - class MyCollection(object): - def __init__(self): self.data = [] - @collection.appender - def append(self, value): 
self.data.append(value) - @collection.remover - def remove(self, value): self.data.remove(value) - @collection.iterator - def __iter__(self): return iter(self.data) - - mapper(Parent, sometable, properties={ - 'children':relation(Child, collection_class=MyCollection) - }) - mapper(Child, someothertable) - - control = list() - p1 = Parent() - - o = Child() - control.append(o) - p1.children.append(o) - assert control == list(p1.children) - - o = Child() - control.append(o) - p1.children.append(o) - assert control == list(p1.children) - - o = Child() - control.append(o) - p1.children.append(o) - assert control == list(p1.children) + assert t3.count().scalar() == 1 - sess = create_session() - sess.save(p1) - sess.flush() - sess.clear() - p2 = sess.query(Parent).get(p1.col1) - o = list(p2.children) - assert len(o) == 3 + class ViewOnlyTest(ORMTest): """test a view_only mapping where a third table is pulled into the primary join condition, @@ -1018,12 +778,12 @@ class ViewOnlyTest(ORMTest): Column('data', String(40)), Column('t2id', Integer, ForeignKey('t2.id')) ) - + def test_basic(self): class C1(object):pass class C2(object):pass class C3(object):pass - + mapper(C1, t1, properties={ 't2s':relation(C2), 't2_view':relation(C2, viewonly=True, primaryjoin=and_(t1.c.id==t2.c.t1id, t3.c.t2id==t2.c.id, t3.c.data==t1.c.data)) @@ -1032,7 +792,7 @@ class ViewOnlyTest(ORMTest): mapper(C3, t3, properties={ 't2':relation(C2) }) - + c1 = C1() c1.data = 'c1data' c2a = C2() @@ -1047,7 +807,7 @@ class ViewOnlyTest(ORMTest): sess.save(c3) sess.flush() sess.clear() - + c1 = sess.query(C1).get(c1.id) assert set([x.id for x in c1.t2s]) == set([c2a.id, c2b.id]) assert set([x.id for x in c1.t2_view]) == set([c2b.id]) @@ -1101,7 +861,520 @@ class ViewOnlyTest2(ORMTest): c1 = sess.query(C1).get(c1.t1id) assert set([x.t2id for x in c1.t2s]) == set([c2a.t2id, c2b.t2id]) assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id]) + +class ViewOnlyTest3(ORMTest): + """test relating on a join that has no equated columns""" + def define_tables(self, metadata): + global foos, bars + foos = Table('foos', metadata, Column('id', Integer, primary_key=True)) + bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer)) + + def test_viewonly_join(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + mapper(Foo, foos, properties={ + 'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, foreign_keys=[bars.c.fid], viewonly=True) + }) + + mapper(Bar, bars) + + sess = create_session() + sess.save(Foo(id=4)) + sess.save(Foo(id=9)) + sess.save(Bar(id=1, fid=2)) + sess.save(Bar(id=2, fid=3)) + sess.save(Bar(id=3, fid=6)) + sess.save(Bar(id=4, fid=7)) + sess.flush() + + sess = create_session() + self.assertEquals(sess.query(Foo).filter_by(id=4).one(), Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)])) + self.assertEquals(sess.query(Foo).filter_by(id=9).one(), Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)])) + +class ViewOnlyTest4(ORMTest): + """test relating on a join that contains the same 'remote' column twice""" + def define_tables(self, metadata): + global foos, bars + foos = Table('foos', metadata, Column('id', Integer, primary_key=True), + Column('bid1', Integer,ForeignKey('bars.id')), + Column('bid2', Integer,ForeignKey('bars.id'))) + + bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('data', String(50))) + + def test_relation_on_or(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + mapper(Foo, foos, 
properties={ + 'bars':relation(Bar, primaryjoin=or_(bars.c.id==foos.c.bid1, bars.c.id==foos.c.bid2), uselist=True, viewonly=True) + }) + + mapper(Bar, bars) + sess = create_session() + b1 = Bar(id=1, data='b1') + b2 = Bar(id=2, data='b2') + b3 = Bar(id=3, data='b3') + f1 = Foo(bid1=1, bid2=2) + f2 = Foo(bid1=3, bid2=None) + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.flush() + sess.save(f1) + sess.save(f2) + sess.flush() + + sess.clear() + self.assertEquals(sess.query(Foo).filter_by(id=f1.id).one(), Foo(bars=[Bar(data='b1'), Bar(data='b2')])) + self.assertEquals(sess.query(Foo).filter_by(id=f2.id).one(), Foo(bars=[Bar(data='b3')])) + +class ViewOnlyTest5(ORMTest): + """test relating on a join that contains the same 'local' column twice""" + def define_tables(self, metadata): + global foos, bars + foos = Table('foos', metadata, Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + bars = Table('bars', metadata, Column('id', Integer, primary_key=True), + Column('fid1', Integer, ForeignKey('foos.id')), + Column('fid2', Integer, ForeignKey('foos.id')), + Column('data', String(50)) + ) + + def test_relation_on_or(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + mapper(Foo, foos, properties={ + 'bars':relation(Bar, primaryjoin=or_(bars.c.fid1==foos.c.id, bars.c.fid2==foos.c.id), viewonly=True) + }) + + mapper(Bar, bars) + sess = create_session() + f1 = Foo(id=1, data='f1') + f2 = Foo(id=2, data='f2') + b1 = Bar(fid1=1, data='b1') + b2 = Bar(fid2=1, data='b2') + b3 = Bar(fid1=2, data='b3') + b4 = Bar(fid1=1, fid2=2, data='b4') + sess.save(f1) + sess.save(f2) + sess.flush() + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.save(b4) + sess.flush() + + sess.clear() + self.assertEquals(sess.query(Foo).filter_by(id=f1.id).one(), Foo(bars=[Bar(data='b1'), Bar(data='b2'), Bar(data='b4')])) + self.assertEquals(sess.query(Foo).filter_by(id=f2.id).one(), Foo(bars=[Bar(data='b3'), Bar(data='b4')])) + +class ViewOnlyTest6(ORMTest): + """test a long primaryjoin condition""" + def define_tables(self, metadata): + global t1, t2, t3, t2tot3 + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t1id', Integer, ForeignKey('t1.id')), + ) + t3 = Table('t3', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + t2tot3 = Table('t2tot3', metadata, + Column('t2id', Integer, ForeignKey('t2.id')), + Column('t3id', Integer, ForeignKey('t3.id')), + ) + + def test_basic(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + class T3(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't3s':relation(T3, primaryjoin=and_( + t1.c.id==t2.c.t1id, + t2.c.id==t2tot3.c.t2id, + t3.c.id==t2tot3.c.t3id + ), + viewonly=True, + foreign_keys=t3.c.id, remote_side=t2.c.t1id) + }) + mapper(T2, t2, properties={ + 't1':relation(T1), + 't3s':relation(T3, secondary=t2tot3) + }) + mapper(T3, t3) + sess = create_session() + sess.save(T2(data='t2', t1=T1(data='t1'), t3s=[T3(data='t3')])) + sess.flush() + sess.clear() + + a = sess.query(T1).first() + self.assertEquals(a.t3s, [T3(data='t3')]) + def test_remote_side_escalation(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + class T3(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't3s':relation(T3, primaryjoin=and_( + t1.c.id==t2.c.t1id, + t2.c.id==t2tot3.c.t2id, + 
t3.c.id==t2tot3.c.t3id + ),viewonly=True, foreign_keys=t3.c.id) + }) + mapper(T2, t2, properties={ + 't1':relation(T1), + 't3s':relation(T3, secondary=t2tot3) + }) + mapper(T3, t3) + self.assertRaisesMessage(exceptions.ArgumentError, "Specify remote_side argument", compile_mappers) + +class ExplicitLocalRemoteTest(ORMTest): + def define_tables(self, metadata): + global t1, t2 + t1 = Table('t1', metadata, + Column('id', String(50), primary_key=True), + Column('data', String(50)) + ) + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t1id', String(50)), + ) + + def test_onetomany_funcfk(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + # use a function within join condition. but specifying + # local_remote_pairs overrides all parsing of the join condition. + mapper(T1, t1, properties={ + 't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id] + ) + }) + mapper(T2, t2) + + sess = create_session() + a1 = T1(id='number1', data='a1') + a2 = T1(id='number2', data='a2') + b1 = T2(data='b1', t1id='NuMbEr1') + b2 = T2(data='b2', t1id='Number1') + b3 = T2(data='b3', t1id='Number2') + sess.save(a1) + sess.save(a2) + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(T1).first(), T1(id='number1', data='a1', t2s=[T2(data='b1', t1id='NuMbEr1'), T2(data='b2', t1id='Number1')])) + + def test_manytoone_funcfk(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + mapper(T1, t1) + mapper(T2, t2, properties={ + 't1':relation(T1, primaryjoin=t1.c.id==func.lower(t2.c.t1id), + _local_remote_pairs=[(t2.c.t1id, t1.c.id)], + foreign_keys=[t2.c.t1id], + uselist=True + ) + }) + sess = create_session() + a1 = T1(id='number1', data='a1') + a2 = T1(id='number2', data='a2') + b1 = T2(data='b1', t1id='NuMbEr1') + b2 = T2(data='b2', t1id='Number1') + b3 = T2(data='b3', t1id='Number2') + sess.save(a1) + sess.save(a2) + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.flush() + sess.clear() + self.assertEquals(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), + [ + T2(data='b1', t1=[T1(id='number1', data='a1')]), + T2(data='b2', t1=[T1(id='number1', data='a1')]) + ] + ) + + def test_onetomany_func_referent(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't2s':relation(T2, primaryjoin=func.lower(t1.c.id)==t2.c.t1id, + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id] + ) + }) + mapper(T2, t2) + + sess = create_session() + a1 = T1(id='NuMbeR1', data='a1') + a2 = T1(id='NuMbeR2', data='a2') + b1 = T2(data='b1', t1id='number1') + b2 = T2(data='b2', t1id='number1') + b3 = T2(data='b2', t1id='number2') + sess.save(a1) + sess.save(a2) + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(T1).first(), T1(id='NuMbeR1', data='a1', t2s=[T2(data='b1', t1id='number1'), T2(data='b2', t1id='number1')])) + + def test_manytoone_func_referent(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + mapper(T1, t1) + mapper(T2, t2, properties={ + 't1':relation(T1, primaryjoin=func.lower(t1.c.id)==t2.c.t1id, + _local_remote_pairs=[(t2.c.t1id, t1.c.id)], + foreign_keys=[t2.c.t1id], uselist=True + ) + }) + + sess = create_session() + a1 = T1(id='NuMbeR1', data='a1') + a2 = T1(id='NuMbeR2', data='a2') + b1 = T2(data='b1', 
t1id='number1') + b2 = T2(data='b2', t1id='number1') + b3 = T2(data='b3', t1id='number2') + sess.save(a1) + sess.save(a2) + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), + [ + T2(data='b1', t1=[T1(id='NuMbeR1', data='a1')]), + T2(data='b2', t1=[T1(id='NuMbeR1', data='a1')]) + ] + ) + + def test_escalation(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id], + remote_side=[t2.c.t1id] + ) + }) + mapper(T2, t2) + self.assertRaises(exceptions.ArgumentError, compile_mappers) + + clear_mappers() + mapper(T1, t1, properties={ + 't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + ) + }) + mapper(T2, t2) + self.assertRaises(exceptions.ArgumentError, compile_mappers) + +class InvalidRelationEscalationTest(ORMTest): + def define_tables(self, metadata): + global foos, bars, Foo, Bar + foos = Table('foos', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer)) + bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer)) + class Foo(object): + pass + class Bar(object): + pass + + def test_no_join(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers) + + def test_no_join_self_ref(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers) + + def test_no_equated(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers) + + def test_no_equated_fks(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, foreign_keys=bars.c.fid) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers) + + def test_no_equated_self_ref(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers) + + def test_no_equated_self_ref(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, foreign_keys=[foos.c.fid]) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers) + + def test_no_equated_viewonly(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, viewonly=True) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers) + + def test_no_equated_self_ref_viewonly(self): + mapper(Foo, 
foos, properties={ + 'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, viewonly=True) + }) + + mapper(Bar, bars) + + self.assertRaisesMessage(exceptions.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers) + + def test_no_equated_self_ref_viewonly_fks(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, viewonly=True, foreign_keys=[foos.c.fid]) + }) + compile_mappers() + self.assertEquals(Foo.foos.property.local_remote_pairs, [(foos.c.id, foos.c.fid)]) + + def test_equated(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, primaryjoin=foos.c.id==bars.c.fid) + }) + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers) + + def test_equated_self_ref(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid) + }) + + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers) + + def test_equated_self_ref_wrong_fks(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id]) + }) + + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers) + +class InvalidRelationEscalationTestM2M(ORMTest): + def define_tables(self, metadata): + global foos, bars, Foo, Bar, foobars + foos = Table('foos', metadata, Column('id', Integer, primary_key=True)) + foobars = Table('foobars', metadata, Column('fid', Integer), Column('bid', Integer)) + bars = Table('bars', metadata, Column('id', Integer, primary_key=True)) + class Foo(object): + pass + class Bar(object): + pass + + def test_no_join(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, secondary=foobars) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers) + + def test_no_secondaryjoin(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id>foobars.c.fid) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers) + + def test_bad_primaryjoin(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id>foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers) + + def test_bad_secondaryjoin(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id==foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id, foreign_keys=[foobars.c.fid]) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers) + + def test_no_equated_secondaryjoin(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id==foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id, foreign_keys=[foobars.c.fid, foobars.c.bid]) + }) + + mapper(Bar, bars) + self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any 
equated, locally mapped column pairs for secondaryjoin condition", compile_mappers) + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/selectable.py b/test/orm/selectable.py new file mode 100644 index 0000000000..fc5be6f505 --- /dev/null +++ b/test/orm/selectable.py @@ -0,0 +1,52 @@ +"""all tests involving generic mapping to Select statements""" + +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy import exceptions +from sqlalchemy.orm import * +from testlib import * +from testlib.fixtures import * +from query import QueryTest + +class SelectableNoFromsTest(ORMTest): + def define_tables(self, metadata): + global common_table + common_table = Table('common', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer), + Column('extra', String(45)), + ) + + def test_no_tables(self): + class Subset(object): + pass + selectable = select(["x", "y", "z"]) + self.assertRaisesMessage(exceptions.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable) + + @testing.emits_warning('.*creating an Alias.*') + def test_basic(self): + class Subset(Base): + pass + + subset_select = select([common_table.c.id, common_table.c.data]) + subset_mapper = mapper(Subset, subset_select) + + sess = create_session(bind=testing.db) + l = Subset() + l.data = 1 + sess.save(l) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(Subset).all(), [Subset(data=1)]) + self.assertEquals(sess.query(Subset).filter(Subset.data==1).one(), Subset(data=1)) + self.assertEquals(sess.query(Subset).filter(Subset.data!=1).first(), None) + + subset_select = class_mapper(Subset).mapped_table + self.assertEquals(sess.query(Subset).filter(subset_select.c.data==1).one(), Subset(data=1)) + + + # TODO: more tests mapping to selects + +if __name__ == '__main__': + testenv.main() diff --git a/test/orm/session.py b/test/orm/session.py index 4332796737..49932f8d9d 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -1,24 +1,34 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions, util from sqlalchemy.orm import * +from sqlalchemy.orm.session import SessionExtension +from sqlalchemy.orm.session import Session as SessionCls from testlib import * from testlib.tables import * -import testlib.tables as tables +from testlib import fixtures, tables +import pickle +import gc -class SessionTest(AssertMixin): + +class SessionTest(TestBase, AssertsExecutionResults): def setUpAll(self): tables.create() + def tearDownAll(self): tables.drop() + def tearDown(self): + SessionCls.close_all() tables.delete() clear_mappers() + def setUp(self): pass def test_close(self): """test that flush() doenst close a connection the session didnt open""" - c = testbase.db.connect() + c = testing.db.connect() class User(object):pass mapper(User, users) s = create_session(bind=c) @@ -35,7 +45,7 @@ class SessionTest(AssertMixin): s.flush() def test_close_two(self): - c = testbase.db.connect() + c = testing.db.connect() try: class User(object):pass mapper(User, users) @@ -72,14 +82,47 @@ class SessionTest(AssertMixin): # then see if expunge fails session.expunge(u) - - @testing.unsupported('sqlite') + + @engines.close_open_connections + def test_binds_from_expression(self): + """test that Session can extract Table objects from ClauseElements and match them to tables.""" + Session = sessionmaker(binds={users:testing.db, addresses:testing.db}) + sess = Session() + 
sess.execute(users.insert(), params=dict(user_id=1, user_name='ed')) + assert sess.execute(users.select()).fetchall() == [(1, 'ed')] + + mapper(Address, addresses) + mapper(User, users, properties={ + 'addresses':relation(Address, backref=backref("user", cascade="all"), cascade="all") + }) + Session = sessionmaker(binds={User:testing.db, Address:testing.db}) + sess.execute(users.insert(), params=dict(user_id=2, user_name='fred')) + assert sess.execute(users.select()).fetchall() == [(1, 'ed'), (2, 'fred')] + sess.close() + + @engines.close_open_connections + def test_bind_from_metadata(self): + Session = sessionmaker() + sess = Session() + mapper(User, users) + + sess.execute(users.insert(), dict(user_name='Johnny')) + + assert len(sess.query(User).all()) == 1 + + sess.execute(users.delete()) + + assert len(sess.query(User).all()) == 0 + sess.close() + + @testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang + @engines.close_open_connections def test_transaction(self): class User(object):pass mapper(User, users) - conn1 = testbase.db.connect() - conn2 = testbase.db.connect() - + conn1 = testing.db.connect() + conn2 = testing.db.connect() + sess = create_session(transactional=True, bind=conn1) u = User() sess.save(u) @@ -88,16 +131,32 @@ class SessionTest(AssertMixin): assert conn2.execute("select count(1) from users").scalar() == 0 sess.commit() assert conn1.execute("select count(1) from users").scalar() == 1 - assert testbase.db.connect().execute("select count(1) from users").scalar() == 1 - - @testing.unsupported('sqlite') + assert testing.db.connect().execute("select count(1) from users").scalar() == 1 + sess.close() + + def test_flush_noop(self): + session = create_session() + session.uow = object() + + self.assertRaises(AttributeError, session.flush) + + session = create_session() + session.uow = object() + + session.flush(objects=[]) + session.flush(objects=set()) + session.flush(objects=()) + session.flush(objects=iter([])) + + @testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang + @engines.close_open_connections def test_autoflush(self): class User(object):pass mapper(User, users) - conn1 = testbase.db.connect() - conn2 = testbase.db.connect() - - sess = create_session(autoflush=True, bind=conn1) + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + sess = create_session(bind=conn1, transactional=True, autoflush=True) u = User() u.user_name='ed' sess.save(u) @@ -107,82 +166,155 @@ class SessionTest(AssertMixin): assert conn2.execute("select count(1) from users").scalar() == 0 sess.commit() assert conn1.execute("select count(1) from users").scalar() == 1 - assert testbase.db.connect().execute("select count(1) from users").scalar() == 1 + assert testing.db.connect().execute("select count(1) from users").scalar() == 1 + sess.close() + + def test_autoflush_expressions(self): + class User(fixtures.Base): + pass + class Address(fixtures.Base): + pass + mapper(User, users, properties={ + 'addresses':relation(Address, backref="user") + }) + mapper(Address, addresses) + + sess = create_session(autoflush=True, transactional=True) + u = User(user_name='ed', addresses=[Address(email_address='foo')]) + sess.save(u) + self.assertEquals(sess.query(Address).filter(Address.user==u).one(), Address(email_address='foo')) - @testing.unsupported('sqlite') + @testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang + @engines.close_open_connections def test_autoflush_unbound(self): class User(object):pass mapper(User, users) try: - sess = 
create_session(autoflush=True) + sess = create_session(transactional=True, autoflush=True) u = User() u.user_name='ed' sess.save(u) u2 = sess.query(User).filter_by(user_name='ed').one() assert u2 is u assert sess.execute("select count(1) from users", mapper=User).scalar() == 1 - assert testbase.db.connect().execute("select count(1) from users").scalar() == 0 + assert testing.db.connect().execute("select count(1) from users").scalar() == 0 sess.commit() assert sess.execute("select count(1) from users", mapper=User).scalar() == 1 - assert testbase.db.connect().execute("select count(1) from users").scalar() == 1 + assert testing.db.connect().execute("select count(1) from users").scalar() == 1 + sess.close() except: sess.rollback() raise - + + @engines.close_open_connections def test_autoflush_2(self): class User(object):pass mapper(User, users) - conn1 = testbase.db.connect() - conn2 = testbase.db.connect() - - sess = create_session(autoflush=True, bind=conn1) + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + sess = create_session(bind=conn1, transactional=True, autoflush=True) u = User() u.user_name='ed' sess.save(u) sess.commit() assert conn1.execute("select count(1) from users").scalar() == 1 - assert testbase.db.connect().execute("select count(1) from users").scalar() == 1 - + assert testing.db.connect().execute("select count(1) from users").scalar() == 1 + sess.commit() + + # TODO: not doing rollback of attributes right now. + def dont_test_autoflush_rollback(self): + tables.data() + mapper(Address, addresses) + mapper(User, users, properties={ + 'addresses':relation(Address) + }) + + sess = create_session(transactional=True, autoflush=True) + u = sess.query(User).get(8) + newad = Address() + newad.email_address == 'something new' + u.addresses.append(newad) + u.user_name = 'some new name' + assert u.user_name == 'some new name' + assert len(u.addresses) == 4 + assert newad in u.addresses + sess.rollback() + assert u.user_name == 'ed' + assert len(u.addresses) == 3 + assert newad not in u.addresses + + + @engines.close_open_connections def test_external_joined_transaction(self): class User(object):pass mapper(User, users) - conn = testbase.db.connect() + conn = testing.db.connect() trans = conn.begin() - sess = create_session(bind=conn) - sess.begin() + sess = create_session(bind=conn, transactional=True, autoflush=True) + sess.begin() u = User() sess.save(u) sess.flush() sess.commit() # commit does nothing trans.rollback() # rolls back - assert len(sess.query(User).select()) == 0 + assert len(sess.query(User).all()) == 0 + sess.close() - @testing.supported('postgres', 'mysql') + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @engines.close_open_connections def test_external_nested_transaction(self): class User(object):pass mapper(User, users) try: - conn = testbase.db.connect() + conn = testing.db.connect() trans = conn.begin() - sess = create_session(bind=conn) + sess = create_session(bind=conn, transactional=True, autoflush=True) u1 = User() sess.save(u1) sess.flush() - - sess.begin_nested() + + sess.begin_nested() u2 = User() sess.save(u2) sess.flush() sess.rollback() - - trans.commit() - assert len(sess.query(User).select()) == 1 + + trans.commit() + assert len(sess.query(User).all()) == 1 except: conn.close() raise - - @testing.supported('postgres', 'mysql') + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @engines.close_open_connections + def test_heavy_nesting(self): 
+ session = create_session(bind=testing.db) + + session.begin() + session.connection().execute("insert into users (user_name) values ('user1')") + + session.begin() + + session.begin_nested() + + session.connection().execute("insert into users (user_name) values ('user2')") + assert session.connection().execute("select count(1) from users").scalar() == 2 + + session.rollback() + assert session.connection().execute("select count(1) from users").scalar() == 1 + session.connection().execute("insert into users (user_name) values ('user3')") + + session.commit() + assert session.connection().execute("select count(1) from users").scalar() == 2 + + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) def test_twophase(self): # TODO: mock up a failure condition here # to ensure a rollback succeeds @@ -190,10 +322,10 @@ class SessionTest(AssertMixin): class Address(object):pass mapper(User, users) mapper(Address, addresses) - - engine2 = create_engine(testbase.db.url) - sess = create_session(twophase=True) - sess.bind_mapper(User, testbase.db) + + engine2 = create_engine(testing.db.url) + sess = create_session(transactional=False, autoflush=False, twophase=True) + sess.bind_mapper(User, testing.db) sess.bind_mapper(Address, engine2) sess.begin() u1 = User() @@ -205,23 +337,23 @@ class SessionTest(AssertMixin): engine2.dispose() assert users.count().scalar() == 1 assert addresses.count().scalar() == 1 - - - + def test_joined_transaction(self): class User(object):pass mapper(User, users) - sess = create_session() + sess = create_session(transactional=True, autoflush=True) sess.begin() - sess.begin() u = User() sess.save(u) sess.flush() sess.commit() # commit does nothing sess.rollback() # rolls back - assert len(sess.query(User).select()) == 0 + assert len(sess.query(User).all()) == 0 + sess.close() - @testing.supported('postgres', 'mysql') + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) def test_nested_transaction(self): class User(object):pass mapper(User, users) @@ -238,12 +370,15 @@ class SessionTest(AssertMixin): sess.save(u2) sess.flush() - sess.rollback() - + sess.rollback() + sess.commit() - assert len(sess.query(User).select()) == 1 + assert len(sess.query(User).all()) == 1 + sess.close() - @testing.supported('postgres', 'mysql') + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) def test_nested_autotrans(self): class User(object):pass mapper(User, users) @@ -258,58 +393,227 @@ class SessionTest(AssertMixin): sess.save(u2) sess.flush() - sess.rollback() + sess.rollback() sess.commit() - assert len(sess.query(User).select()) == 1 + assert len(sess.query(User).all()) == 1 + sess.close() + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) + def test_nested_transaction_connection_add(self): + class User(object): pass + mapper(User, users) + + sess = create_session(transactional=False) + + sess.begin() + sess.begin_nested() + + u1 = User() + sess.save(u1) + sess.flush() + + sess.rollback() + + u2 = User() + sess.save(u2) + sess.commit() + + self.assertEquals(util.Set(sess.query(User).all()), util.Set([u2])) + + sess.begin() + sess.begin_nested() + + u3 = User() + sess.save(u3) + sess.commit() # commit the nested transaction + sess.rollback() + + 
self.assertEquals(util.Set(sess.query(User).all()), util.Set([u2])) + + sess.close() + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) + def test_mixed_transaction_control(self): + class User(object): pass + mapper(User, users) + + sess = create_session(transactional=False) + + sess.begin() + sess.begin_nested() + transaction = sess.begin() + + sess.save(User()) + + transaction.commit() + sess.commit() + sess.commit() + + sess.close() + + self.assertEquals(len(sess.query(User).all()), 1) + + t1 = sess.begin() + t2 = sess.begin_nested() + + sess.save(User()) + + t2.commit() + assert sess.transaction is t1 + + sess.close() + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) + def test_mixed_transaction_close(self): + class User(object): pass + mapper(User, users) + + sess = create_session(transactional=True) + + sess.begin_nested() + + sess.save(User()) + sess.flush() + + sess.close() + + sess.save(User()) + sess.commit() + + sess.close() + + self.assertEquals(len(sess.query(User).all()), 1) + + @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access', + 'oracle', 'maxdb') + @testing.exclude('mysql', '<', (5, 0, 3)) + def test_error_on_using_inactive_session(self): + class User(object): pass + mapper(User, users) + + sess = create_session(transactional=False) + + try: + sess.begin() + sess.begin() + + sess.save(User()) + sess.flush() + + sess.rollback() + sess.begin() + assert False + except exceptions.InvalidRequestError, e: + self.assertEquals(str(e), "The transaction is inactive due to a rollback in a subtransaction and should be closed") + sess.close() + + @engines.close_open_connections def test_bound_connection(self): class User(object):pass mapper(User, users) - c = testbase.db.connect() + c = testing.db.connect() sess = create_session(bind=c) sess.create_transaction() transaction = sess.transaction u = User() sess.save(u) sess.flush() - assert transaction.get_or_add(testbase.db) is transaction.get_or_add(c) is c - + assert transaction.get_or_add(testing.db) is transaction.get_or_add(c) is c + try: - transaction.add(testbase.db.connect()) + transaction.add(testing.db.connect()) assert False - except exceptions.InvalidRequestError, e: + except exceptions.InvalidRequestError, e: assert str(e) == "Session already has a Connection associated for the given Connection's Engine" try: - transaction.get_or_add(testbase.db.connect()) + transaction.get_or_add(testing.db.connect()) assert False - except exceptions.InvalidRequestError, e: + except exceptions.InvalidRequestError, e: assert str(e) == "Session already has a Connection associated for the given Connection's Engine" try: - transaction.add(testbase.db) + transaction.add(testing.db) assert False - except exceptions.InvalidRequestError, e: + except exceptions.InvalidRequestError, e: assert str(e) == "Session already has a Connection associated for the given Engine" - + transaction.rollback() - assert len(sess.query(User).select()) == 0 - - def test_update(self): - """test that the update() method functions and doesnet blow away changes""" - tables.delete() - s = create_session() + assert len(sess.query(User).all()) == 0 + sess.close() + + def test_bound_connection_transactional(self): class User(object):pass mapper(User, users) - - # save user - s.save(User()) + c = testing.db.connect() + + sess = create_session(bind=c, transactional=True) + u = 
User() + sess.save(u) + sess.flush() + sess.close() + assert not c.in_transaction() + assert c.scalar("select count(1) from users") == 0 + + sess = create_session(bind=c, transactional=True) + u = User() + sess.save(u) + sess.flush() + sess.commit() + assert not c.in_transaction() + assert c.scalar("select count(1) from users") == 1 + c.execute("delete from users") + assert c.scalar("select count(1) from users") == 0 + + c = testing.db.connect() + + trans = c.begin() + sess = create_session(bind=c, transactional=False) + u = User() + sess.save(u) + sess.flush() + assert c.in_transaction() + trans.commit() + assert not c.in_transaction() + assert c.scalar("select count(1) from users") == 1 + + + @engines.close_open_connections + def test_save_update_delete(self): + + s = create_session() + class User(object): + pass + mapper(User, users) + + user = User() + + try: + s.update(user) + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user)) + + try: + s.delete(user) + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user)) + + s.save(user) s.flush() - user = s.query(User).selectone() + user = s.query(User).one() s.expunge(user) assert user not in s - + # modify outside of session, assert changes remain/get saved user.user_name = "fred" s.update(user) @@ -317,33 +621,177 @@ class SessionTest(AssertMixin): assert user in s.dirty s.flush() s.clear() - user = s.query(User).selectone() + assert s.query(User).count() == 1 + user = s.query(User).one() assert user.user_name == 'fred' - + # ensure its not dirty if no changes occur s.clear() assert user not in s s.update(user) assert user in s assert user not in s.dirty - - def test_strong_ref(self): - """test that the session is strong-referencing""" - tables.delete() + + try: + s.save(user) + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Instance 'User@%s' is already persistent" % hex(id(user)) + + s2 = create_session() + try: + s2.delete(user) + assert False + except exceptions.InvalidRequestError, e: + assert "is already attached to session" in str(e) + + u2 = s2.query(User).get(user.user_id) + try: + s.delete(u2) + assert False + except exceptions.InvalidRequestError, e: + assert "already persisted with a different identity" in str(e) + + s.delete(user) + s.flush() + assert user not in s + assert s.query(User).count() == 0 + + def test_is_modified(self): s = create_session() class User(object):pass + class Address(object):pass + + mapper(User, users, properties={'addresses':relation(Address)}) + mapper(Address, addresses) + + # save user + u = User() + u.user_name = 'fred' + s.save(u) + s.flush() + s.clear() + + user = s.query(User).one() + assert user not in s.dirty + assert not s.is_modified(user) + user.user_name = 'fred' + assert user in s.dirty + assert not s.is_modified(user) + user.user_name = 'ed' + assert user in s.dirty + assert s.is_modified(user) + s.flush() + assert user not in s.dirty + assert not s.is_modified(user) + + a = Address() + user.addresses.append(a) + assert user in s.dirty + assert s.is_modified(user) + assert not s.is_modified(user, include_collections=False) + + + def test_weak_ref(self): + """test the weak-referencing identity map, which strongly-references modified items.""" + + s = create_session() + class User(fixtures.Base):pass mapper(User, users) - + + s.save(User(user_name='ed')) + s.flush() + assert not s.dirty + + user = 
s.query(User).one() + del user + gc.collect() + assert len(s.identity_map) == 0 + assert len(s.identity_map.data) == 0 + + user = s.query(User).one() + user.user_name = 'fred' + del user + gc.collect() + assert len(s.identity_map) == 1 + assert len(s.identity_map.data) == 1 + assert len(s.dirty) == 1 + + s.flush() + gc.collect() + assert not s.dirty + assert not s.identity_map + assert not s.identity_map.data + + user = s.query(User).one() + assert user.user_name == 'fred' + assert s.identity_map + + def test_strong_ref(self): + s = create_session(weak_identity_map=False) + class User(object):pass + mapper(User, users) + # save user s.save(User()) s.flush() - user = s.query(User).selectone() + user = s.query(User).one() user = None print s.identity_map import gc gc.collect() assert len(s.identity_map) == 1 - + + def test_prune(self): + s = create_session(weak_identity_map=False) + class User(object):pass + mapper(User, users) + + for o in [User() for x in xrange(10)]: + s.save(o) + # o is still live after this loop... + + self.assert_(len(s.identity_map) == 0) + self.assert_(s.prune() == 0) + s.flush() + import gc + gc.collect() + self.assert_(s.prune() == 9) + self.assert_(len(s.identity_map) == 1) + + user_id = o.user_id + del o + self.assert_(s.prune() == 1) + self.assert_(len(s.identity_map) == 0) + + u = s.query(User).get(user_id) + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 1) + u.user_name = 'squiznart' + del u + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 1) + s.flush() + self.assert_(s.prune() == 1) + self.assert_(len(s.identity_map) == 0) + + s.save(User()) + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 0) + s.flush() + self.assert_(len(s.identity_map) == 1) + self.assert_(s.prune() == 1) + self.assert_(len(s.identity_map) == 0) + + u = s.query(User).get(user_id) + s.delete(u) + del u + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 1) + s.flush() + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 0) + def test_no_save_cascade(self): mapper(Address, addresses) mapper(User, users, properties=dict( @@ -357,18 +805,19 @@ class SessionTest(AssertMixin): assert u in s assert a not in s s.flush() + print "\n".join([repr(x.__dict__) for x in s]) s.clear() - assert s.query(User).selectone().user_id == u.user_id - assert s.query(Address).selectfirst() is None - + assert s.query(User).one().user_id == u.user_id + assert s.query(Address).first() is None + clear_mappers() - + tables.delete() mapper(Address, addresses) mapper(User, users, properties=dict( addresses=relation(Address, cascade="all", backref=backref("user", cascade="none")) )) - + s = create_session() u = User() a = Address() @@ -378,8 +827,8 @@ class SessionTest(AssertMixin): assert a in s s.flush() s.clear() - assert s.query(Address).selectone().address_id == a.address_id - assert s.query(User).selectfirst() is None + assert s.query(Address).one().address_id == a.address_id + assert s.query(User).first() is None def _assert_key(self, got, expect): assert got == expect, "expected %r got %r" % (expect, got) @@ -415,7 +864,315 @@ class SessionTest(AssertMixin): self._assert_key(key, (User, (1,), None)) key = s.identity_key(User, row=row, entity_name="en") self._assert_key(key, (User, (1,), "en")) + + def test_extension(self): + mapper(User, users) + log = [] + class MyExt(SessionExtension): + def before_commit(self, session): + log.append('before_commit') + def after_commit(self, session): + log.append('after_commit') + def 
after_rollback(self, session): + log.append('after_rollback') + def before_flush(self, session, flush_context, objects): + log.append('before_flush') + def after_flush(self, session, flush_context): + log.append('after_flush') + def after_flush_postexec(self, session, flush_context): + log.append('after_flush_postexec') + def after_begin(self, session, transaction, connection): + log.append('after_begin') + sess = create_session(extension = MyExt()) + u = User() + sess.save(u) + sess.flush() + assert log == ['before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec'] + + log = [] + sess = create_session(transactional=True, extension=MyExt()) + u = User() + sess.save(u) + sess.flush() + assert log == ['before_flush', 'after_begin', 'after_flush', 'after_flush_postexec'] + + log = [] + u.user_name = 'ed' + sess.commit() + assert log == ['before_commit', 'before_flush', 'after_flush', 'after_flush_postexec', 'after_commit'] + + log = [] + sess.commit() + assert log == ['before_commit', 'after_commit'] - -if __name__ == "__main__": - testbase.main() + log = [] + sess = create_session(transactional=True, extension=MyExt(), bind=testing.db) + conn = sess.connection() + assert log == ['after_begin'] + + def test_pickled_update(self): + mapper(User, users) + sess1 = create_session() + sess2 = create_session() + + u1 = User() + sess1.save(u1) + + try: + sess2.save(u1) + assert False + except exceptions.InvalidRequestError, e: + assert "already attached to session" in str(e) + + u2 = pickle.loads(pickle.dumps(u1)) + + sess2.save(u2) + + def test_duplicate_update(self): + mapper(User, users) + Session = sessionmaker() + sess = Session() + + u1 = User() + sess.save(u1) + sess.flush() + assert u1.user_id is not None + + sess.expunge(u1) + + assert u1 not in sess + + u2 = sess.query(User).get(u1.user_id) + assert u2 is not None and u2 is not u1 + assert u2 in sess + + self.assertRaises(Exception, lambda: sess.update(u1)) + + sess.expunge(u2) + assert u2 not in sess + + u1.user_name = "John" + u2.user_name = "Doe" + + sess.update(u1) + assert u1 in sess + + sess.flush() + + sess.clear() + + u3 = sess.query(User).get(u1.user_id) + assert u3 is not u1 and u3 is not u2 and u3.user_name == u1.user_name + + def test_no_double_save(self): + sess = create_session() + class Foo(object): + def __init__(self): + sess.save(self) + class Bar(Foo): + def __init__(self): + sess.save(self) + Foo.__init__(self) + mapper(Foo, users) + mapper(Bar, users) + + b = Bar() + assert b in sess + assert len(list(sess)) == 1 + + +class ScopedSessionTest(ORMTest): + + def define_tables(self, metadata): + global table, table2 + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + table2 = Table('someothertable', metadata, + Column('id', Integer, primary_key=True), + Column('someid', None, ForeignKey('sometable.id')) + ) + + def test_basic(self): + Session = scoped_session(sessionmaker()) + + class SomeObject(fixtures.Base): + query = Session.query_property() + class SomeOtherObject(fixtures.Base): + query = Session.query_property() + + mapper(SomeObject, table, properties={ + 'options':relation(SomeOtherObject) + }) + mapper(SomeOtherObject, table2) + + s = SomeObject(id=1, data="hello") + sso = SomeOtherObject() + s.options.append(sso) + Session.save(s) + Session.commit() + Session.remove() + + self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), Session.query(SomeObject).one()) + 
self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), SomeObject.query.one()) + self.assertEquals(SomeOtherObject(someid=1), SomeOtherObject.query.filter(SomeOtherObject.someid==sso.someid).one()) + +class ScopedMapperTest(TestBase): + def setUpAll(self): + global metadata, table, table2 + metadata = MetaData(testing.db) + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30), nullable=False)) + table2 = Table('someothertable', metadata, + Column('id', Integer, primary_key=True), + Column('someid', None, ForeignKey('sometable.id')) + ) + metadata.create_all() + + def setUp(self): + global SomeObject, SomeOtherObject + class SomeObject(object):pass + class SomeOtherObject(object):pass + + global Session + + Session = scoped_session(create_session) + Session.mapper(SomeObject, table, properties={ + 'options':relation(SomeOtherObject) + }) + Session.mapper(SomeOtherObject, table2) + + s = SomeObject() + s.id = 1 + s.data = 'hello' + sso = SomeOtherObject() + s.options.append(sso) + Session.flush() + Session.clear() + + def tearDownAll(self): + metadata.drop_all() + + def tearDown(self): + for table in metadata.table_iterator(reverse=True): + table.delete().execute() + clear_mappers() + + def test_query(self): + sso = SomeOtherObject.query().first() + assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id + + def test_query_compiles(self): + class Foo(object): + pass + Session.mapper(Foo, table2) + assert hasattr(Foo, 'query') + + ext = MapperExtension() + + class Bar(object): + pass + Session.mapper(Bar, table2, extension=[ext]) + assert hasattr(Bar, 'query') + + class Baz(object): + pass + Session.mapper(Baz, table2, extension=ext) + assert hasattr(Baz, 'query') + + def test_validating_constructor(self): + s2 = SomeObject(someid=12) + s3 = SomeOtherObject(someid=123, bogus=345) + + class ValidatedOtherObject(object):pass + Session.mapper(ValidatedOtherObject, table2, validate=True) + + v1 = ValidatedOtherObject(someid=12) + try: + v2 = ValidatedOtherObject(someid=12, bogus=345) + assert False + except exceptions.ArgumentError: + pass + + def test_dont_clobber_methods(self): + class MyClass(object): + def expunge(self): + return "an expunge !" + + Session.mapper(MyClass, table2) + + assert MyClass().expunge() == "an expunge !" 
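# A minimal, self-contained sketch of the scoped_session / Session.mapper
# pattern exercised by the ScopedSessionTest and ScopedMapperTest cases above,
# assuming the SQLAlchemy 0.4.x API; the "parent" table and Parent class are
# hypothetical names used only for illustration and are not part of the suite.
from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine('sqlite://')
metadata = MetaData(engine)            # bound metadata supplies the engine to the session
parent = Table('parent', metadata,
    Column('id', Integer, primary_key=True),
    Column('data', String(30)))
metadata.create_all()

Session = scoped_session(sessionmaker(transactional=True, autoflush=True))

class Parent(object):
    pass

# Session.mapper attaches a class-level .query and, by default, saves new
# instances into the contextual (thread-local) session as they are constructed.
Session.mapper(Parent, parent)

p = Parent()
p.data = 'hello'
Session.commit()                       # flush pending changes and commit

# the identity map hands back the same instance on a subsequent query
assert Parent.query.filter_by(data='hello').one() is p
Session.remove()                       # discard the current thread-local session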
+ + def _test_autoflush_saveoninit(self, on_init, autoflush=None): + Session = scoped_session( + sessionmaker(transactional=True, autoflush=True)) + + class Foo(object): + def __init__(self, data=None): + if autoflush is not None: + friends = Session.query(Foo).autoflush(autoflush).all() + else: + friends = Session.query(Foo).all() + self.data = data + + Session.mapper(Foo, table, save_on_init=on_init) + + a1 = Foo('an address') + Session.flush() + + def test_autoflush_saveoninit(self): + """Test save_on_init + query.autoflush()""" + self._test_autoflush_saveoninit(False) + self._test_autoflush_saveoninit(False, True) + self._test_autoflush_saveoninit(False, False) + + self.assertRaises(exceptions.DBAPIError, + self._test_autoflush_saveoninit, True) + self.assertRaises(exceptions.DBAPIError, + self._test_autoflush_saveoninit, True, True) + self._test_autoflush_saveoninit(True, False) + + +class ScopedMapperTest2(ORMTest): + def define_tables(self, metadata): + global table, table2 + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('type', String(30)) + + ) + table2 = Table('someothertable', metadata, + Column('id', Integer, primary_key=True), + Column('someid', None, ForeignKey('sometable.id')), + Column('somedata', String(30)), + ) + + def test_inheritance(self): + def expunge_list(l): + for x in l: + Session.expunge(x) + return l + + class BaseClass(fixtures.Base): + pass + class SubClass(BaseClass): + pass + + Session = scoped_session(sessionmaker()) + Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type) + Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass) + + b = BaseClass(data='b1') + s = SubClass(data='s1', somedata='somedata') + Session.commit() + Session.clear() + + assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all() + assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all() + + + +if __name__ == "__main__": + testenv.main() diff --git a/test/orm/sessioncontext.py b/test/orm/sessioncontext.py index 7a60b47c7e..c743dabf99 100644 --- a/test/orm/sessioncontext.py +++ b/test/orm/sessioncontext.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext @@ -12,10 +12,10 @@ users = Table('users', metadata, Column('user_name', String(40)), ) -class SessionContextTest(AssertMixin): +class SessionContextTest(TestBase, AssertsExecutionResults): def setUp(self): clear_mappers() - + def do_test(self, class_, context): """test session assignment on object creation""" obj = class_() @@ -32,10 +32,11 @@ class SessionContextTest(AssertMixin): del context.current assert context.current != new_session assert old_session == object_session(obj) - + obj2 = class_() assert context.current == object_session(obj2) - + + @testing.uses_deprecated('SessionContext') def test_mapper_extension(self): context = SessionContext(Session) class User(object): pass @@ -44,4 +45,4 @@ class SessionContextTest(AssertMixin): if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/orm/sharding/alltests.py b/test/orm/sharding/alltests.py index 0cdb838a9d..aab3993f32 100644 --- a/test/orm/sharding/alltests.py +++ b/test/orm/sharding/alltests.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest def 
suite(): @@ -15,4 +15,4 @@ def suite(): if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py index faa980cc27..d231b14a2c 100644 --- a/test/orm/sharding/shard.py +++ b/test/orm/sharding/shard.py @@ -1,18 +1,18 @@ -import testbase +import testenv; testenv.configure_for_tests() +import datetime, os from sqlalchemy import * +from sqlalchemy import exceptions, sql from sqlalchemy.orm import * - from sqlalchemy.orm.shard import ShardedSession -from sqlalchemy.sql import ColumnOperators -import datetime, operator, os -from testlib import PersistTest +from sqlalchemy.sql import operators +from testlib import * # TODO: ShardTest can be turned into a base for further subclasses -class ShardTest(PersistTest): +class ShardTest(TestBase): def setUpAll(self): global db1, db2, db3, db4, weather_locations, weather_reports - + db1 = create_engine('sqlite:///shard1.db') db2 = create_engine('sqlite:///shard2.db') db3 = create_engine('sqlite:///shard3.db') @@ -41,16 +41,18 @@ class ShardTest(PersistTest): Column('temperature', Float), Column('report_time', DateTime, default=datetime.datetime.now), ) - + for db in (db1, db2, db3, db4): meta.create_all(db) - + db1.execute(ids.insert(), nextid=1) self.setup_session() self.setup_mappers() - + def tearDownAll(self): + for db in (db1, db2, db3, db4): + db.connect().invalidate() for i in range(1,5): os.remove("shard%d.db" % i) @@ -63,14 +65,14 @@ class ShardTest(PersistTest): 'Europe':'europe', 'South America':'south_america' } - - def shard_chooser(mapper, instance): + + def shard_chooser(mapper, instance, clause=None): if isinstance(instance, WeatherLocation): return shard_lookup[instance.continent] else: return shard_chooser(mapper, instance.location) - def id_chooser(ident): + def id_chooser(query, ident): return ['north_america', 'asia', 'europe', 'south_america'] def query_chooser(query): @@ -79,9 +81,9 @@ class ShardTest(PersistTest): class FindContinent(sql.ClauseVisitor): def visit_binary(self, binary): if binary.left is weather_locations.c.continent: - if binary.operator == operator.eq: + if binary.operator == operators.eq: ids.append(shard_lookup[binary.right.value]) - elif binary.operator == ColumnOperators.in_op: + elif binary.operator == operators.in_op: for bind in binary.right.clauses: ids.append(shard_lookup[bind.value]) @@ -91,17 +93,19 @@ class ShardTest(PersistTest): else: return ids - def create_session(): - s = ShardedSession(shard_chooser, id_chooser, query_chooser) - s.bind_shard('north_america', db1) - s.bind_shard('asia', db2) - s.bind_shard('europe', db3) - s.bind_shard('south_america', db4) - return s + create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True) + + create_session.configure(shards={ + 'north_america':db1, + 'asia':db2, + 'europe':db3, + 'south_america':db4 + }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) + def setup_mappers(self): global WeatherLocation, Report - + class WeatherLocation(object): def __init__(self, continent, city): self.continent = continent @@ -112,10 +116,11 @@ class ShardTest(PersistTest): self.temperature = temperature mapper(WeatherLocation, weather_locations, properties={ - 'reports':relation(Report, backref='location') + 'reports':relation(Report, backref='location'), + 'city': deferred(weather_locations.c.city), }) - mapper(Report, weather_reports) + mapper(Report, weather_reports) def test_roundtrip(self): tokyo = WeatherLocation('Asia', 
'Tokyo') @@ -133,10 +138,13 @@ class ShardTest(PersistTest): sess = create_session() for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: sess.save(c) - sess.flush() + sess.commit() sess.clear() + assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')] + assert db1.execute(weather_locations.select()).fetchall() == [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')] + t = sess.query(WeatherLocation).get(tokyo.id) assert t.city == tokyo.city assert t.reports[0].temperature == 80.0 @@ -144,11 +152,10 @@ class ShardTest(PersistTest): north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America') assert set([c.city for c in north_american_cities]) == set(['New York', 'Toronto']) - asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_('Europe', 'Asia')) + asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_(['Europe', 'Asia'])) assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin']) if __name__ == '__main__': - testbase.main() - \ No newline at end of file + testenv.main() diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index ae626db849..cd2a3005ea 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -1,139 +1,121 @@ -import testbase +# coding: utf-8 + +"""Tests unitofwork operations.""" + +import testenv; testenv.configure_for_tests() import pickleable from sqlalchemy import * +from sqlalchemy import exceptions, sql from sqlalchemy.orm import * -from sqlalchemy.orm.mapper import global_extensions -from sqlalchemy.orm import util as ormutil -from sqlalchemy.ext.sessioncontext import SessionContext -import sqlalchemy.ext.assignmapper as assignmapper from testlib import * from testlib.tables import * -from testlib import tables - -"""tests unitofwork operations""" - -class UnitOfWorkTest(AssertMixin): - def setUpAll(self): - global ctx, assign_mapper - ctx = SessionContext(create_session) - def assign_mapper(*args, **kwargs): - return assignmapper.assign_mapper(ctx, *args, **kwargs) - global_extensions.append(ctx.mapper_extension) - def tearDownAll(self): - global_extensions.remove(ctx.mapper_extension) - def tearDown(self): - ctx.current.clear() - clear_mappers() +from testlib import engines, tables, fixtures -class HistoryTest(UnitOfWorkTest): - def setUpAll(self): - tables.metadata.bind = testbase.db - UnitOfWorkTest.setUpAll(self) - users.create() - addresses.create() - def tearDownAll(self): - addresses.drop() - users.drop() - UnitOfWorkTest.tearDownAll(self) - - def testbackref(self): - s = create_session() + +# TODO: convert suite to not use Session.mapper, use fixtures.Base +# with explicit session.save() +Session = scoped_session(sessionmaker(autoflush=True, transactional=True)) +orm_mapper = mapper +mapper = Session.mapper + +class UnitOfWorkTest(object): + pass + +class HistoryTest(ORMTest): + metadata = tables.metadata + def define_tables(self, metadata): + pass + + def test_backref(self): + s = Session() class User(object):pass class Address(object):pass am = mapper(Address, addresses) m = mapper(User, users, properties = dict( addresses = relation(am, backref='user', lazy=False)) ) - + u = User(_sa_session=s) a = Address(_sa_session=s) a.user = u - #print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items()) - #print repr(u.addresses.added_items()) + self.assert_(u.addresses == [a]) - s.flush() + s.commit() - s.clear() - u = s.query(m).select()[0] + 
s.close() + u = s.query(m).all()[0] print u.addresses[0].user - -class VersioningTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - ctx.current.clear() +class VersioningTest(ORMTest): + def define_tables(self, metadata): global version_table - version_table = Table('version_test', MetaData(testbase.db), - Column('id', Integer, Sequence('version_test_seq'), primary_key=True ), + version_table = Table('version_test', metadata, + Column('id', Integer, Sequence('version_test_seq', optional=True), + primary_key=True ), Column('version_id', Integer, nullable=False), Column('value', String(40), nullable=False) ) - version_table.create() - def tearDownAll(self): - version_table.drop() - UnitOfWorkTest.tearDownAll(self) - def tearDown(self): - version_table.delete().execute() - UnitOfWorkTest.tearDown(self) - - def testbasic(self): - s = create_session() + + @engines.close_open_connections + def test_basic(self): + s = Session(scope=None) class Foo(object):pass - assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id) - f1 =Foo(value='f1', _sa_session=s) + mapper(Foo, version_table, version_id_col=version_table.c.version_id) + f1 = Foo(value='f1', _sa_session=s) f2 = Foo(value='f2', _sa_session=s) - s.flush() - + s.commit() + f1.value='f1rev2' - s.flush() - s2 = create_session() + s.commit() + s2 = Session() f1_s = s2.query(Foo).get(f1.id) f1_s.value='f1rev3' - s2.flush() + s2.commit() f1.value='f1rev3mine' success = False try: # a concurrent session has modified this, should throw # an exception - s.flush() + s.commit() except exceptions.ConcurrentModificationError, e: #print e success = True # Only dialects with a sane rowcount can detect the ConcurrentModificationError - if testbase.db.dialect.supports_sane_rowcount(): + if testing.db.dialect.supports_sane_rowcount: assert success - - s.clear() + + s.close() f1 = s.query(Foo).get(f1.id) f2 = s.query(Foo).get(f2.id) - + f1_s.value='f1rev4' - s2.flush() - + s2.commit() + s.delete(f1) s.delete(f2) success = False try: - s.flush() + s.commit() except exceptions.ConcurrentModificationError, e: #print e success = True - if testbase.db.dialect.supports_sane_rowcount(): + if testing.db.dialect.supports_sane_multi_rowcount: assert success - def testversioncheck(self): + @engines.close_open_connections + def test_versioncheck(self): """test that query.with_lockmode performs a 'version check' on an already loaded instance""" - s1 = create_session() + s1 = Session(scope=None) class Foo(object):pass - assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id) + mapper(Foo, version_table, version_id_col=version_table.c.version_id) f1s1 =Foo(value='f1', _sa_session=s1) - s1.flush() - s2 = create_session() + s1.commit() + s2 = Session() f1s2 = s2.query(Foo).get(f1s1.id) f1s2.value='f1 new value' - s2.flush() + s2.commit() try: # load, version is wrong s1.query(Foo).with_lockmode('read').get(f1s1.id) @@ -144,44 +126,36 @@ class VersioningTest(UnitOfWorkTest): s1.query(Foo).load(f1s1.id) # now assert version OK s1.query(Foo).with_lockmode('read').get(f1s1.id) - + # assert brand new load is OK too - s1.clear() + s1.close() s1.query(Foo).with_lockmode('read').get(f1s1.id) - - def testnoversioncheck(self): + + @engines.close_open_connections + def test_noversioncheck(self): """test that query.with_lockmode works OK when the mapper has no version id col""" - s1 = create_session() + s1 = Session() class Foo(object):pass - assign_mapper(Foo, version_table) + mapper(Foo, version_table) f1s1 
=Foo(value='f1', _sa_session=s1) f1s1.version_id=0 - s1.flush() - s2 = create_session() + s1.commit() + s2 = Session() f1s2 = s2.query(Foo).with_lockmode('read').get(f1s1.id) assert f1s2.id == f1s1.id assert f1s2.value == f1s1.value - -class UnicodeTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - global metadata, uni_table, uni_table2 - metadata = MetaData(testbase.db) + +class UnicodeTest(ORMTest): + def define_tables(self, metadata): + global uni_table, uni_table2 uni_table = Table('uni_test', metadata, Column('id', Integer, Sequence("uni_test_id_seq", optional=True), primary_key=True), Column('txt', Unicode(50), unique=True)) uni_table2 = Table('uni2', metadata, Column('id', Integer, Sequence("uni2_test_id_seq", optional=True), primary_key=True), Column('txt', Unicode(50), ForeignKey(uni_table.c.txt))) - metadata.create_all() - def tearDownAll(self): - metadata.drop_all() - UnitOfWorkTest.tearDownAll(self) - def tearDown(self): - clear_mappers() - for t in metadata.table_iterator(reverse=True): - t.delete().execute() - def testbasic(self): + + def test_basic(self): class Test(object): def __init__(self, id, txt): self.id = id @@ -191,153 +165,255 @@ class UnicodeTest(UnitOfWorkTest): txt = u"\u0160\u0110\u0106\u010c\u017d" t1 = Test(id=1, txt = txt) self.assert_(t1.txt == txt) - ctx.current.flush() + Session.commit() self.assert_(t1.txt == txt) - def testrelation(self): + + def test_relation(self): class Test(object): def __init__(self, txt): self.txt = txt class Test2(object):pass - + mapper(Test, uni_table, properties={ 't2s':relation(Test2) }) mapper(Test2, uni_table2) - + txt = u"\u0160\u0110\u0106\u010c\u017d" t1 = Test(txt=txt) t1.t2s.append(Test2()) t1.t2s.append(Test2()) - ctx.current.flush() - ctx.current.clear() - t1 = ctx.current.query(Test).get_by(id=t1.id) + Session.commit() + Session.close() + t1 = Session.query(Test).filter_by(id=t1.id).one() assert len(t1.t2s) == 2 -class MutableTypesTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - global metadata, table - metadata = MetaData(testbase.db) +class UnicodeSchemaTest(ORMTest): + __unsupported_on__ = ('oracle', 'mssql', 'firebird', 'sybase', + 'access', 'maxdb') + __excluded_on__ = (('mysql', '<', (4, 1, 1)),) + + metadata = MetaData(engines.utf8_engine()) + + def define_tables(self, metadata): + global t1, t2 + + t1 = Table('unitable1', metadata, + Column(u'méil', Integer, primary_key=True, key='a'), + Column(u'\u6e2c\u8a66', Integer, key='b'), + Column('type', String(20)), + test_needs_fk=True, + ) + t2 = Table(u'Unitéble2', metadata, + Column(u'méil', Integer, primary_key=True, key="cc"), + Column(u'\u6e2c\u8a66', Integer, ForeignKey(u'unitable1.a'), key="d"), + Column(u'\u6e2c\u8a66_2', Integer, key="e"), + test_needs_fk=True, + ) + + def test_mapping(self): + class A(fixtures.Base):pass + class B(fixtures.Base):pass + + mapper(A, t1, properties={ + 't2s':relation(B), + }) + mapper(B, t2) + a1 = A() + b1 = B() + a1.t2s.append(b1) + Session.flush() + Session.clear() + new_a1 = Session.query(A).filter(t1.c.a == a1.a).one() + assert new_a1.a == a1.a + assert new_a1.t2s[0].d == b1.d + Session.clear() + + new_a1 = Session.query(A).options(eagerload('t2s')).filter(t1.c.a == a1.a).one() + assert new_a1.a == a1.a + assert new_a1.t2s[0].d == b1.d + Session.clear() + + new_a1 = Session.query(A).filter(A.a == a1.a).one() + assert new_a1.a == a1.a + assert new_a1.t2s[0].d == b1.d + Session.clear() + + def test_inheritance_mapping(self): + class A(fixtures.Base):pass + class 
B(A):pass + mapper(A, t1, polymorphic_on=t1.c.type, polymorphic_identity='a') + mapper(B, t2, inherits=A, polymorphic_identity='b') + a1 = A(b=5) + b1 = B(e=7) + + Session.flush() + Session.clear() + # TODO: somehow, not assigning to "l" first + # breaks the comparison ????? + l = Session.query(A).all() + assert [A(b=5), B(e=7)] == l + +class MutableTypesTest(ORMTest): + def define_tables(self, metadata): + global table table = Table('mutabletest', metadata, Column('id', Integer, Sequence('mutableidseq', optional=True), primary_key=True), Column('data', PickleType), - Column('value', Unicode(30))) - table.create() - def tearDownAll(self): - table.drop() - UnitOfWorkTest.tearDownAll(self) + Column('val', Unicode(30))) - def testbasic(self): + def test_basic(self): """test that types marked as MutableType get changes detected on them""" class Foo(object):pass mapper(Foo, table) f1 = Foo() f1.data = pickleable.Bar(4,5) - ctx.current.flush() - ctx.current.clear() - f2 = ctx.current.query(Foo).get_by(id=f1.id) + Session.commit() + Session.close() + f2 = Session.query(Foo).filter_by(id=f1.id).one() assert f2.data == f1.data f2.data.y = 19 - ctx.current.flush() - ctx.current.clear() - f3 = ctx.current.query(Foo).get_by(id=f1.id) + assert f2 in Session.dirty + Session.commit() + Session.close() + f3 = Session.query(Foo).filter_by(id=f1.id).one() print f2.data, f3.data assert f3.data != f1.data assert f3.data == pickleable.Bar(4, 19) - def testmutablechanges(self): + def test_mutablechanges(self): """test that mutable changes are detected or not detected correctly""" class Foo(object):pass mapper(Foo, table) f1 = Foo() f1.data = pickleable.Bar(4,5) - f1.value = unicode('hi') - ctx.current.flush() + f1.val = unicode('hi') + Session.commit() def go(): - ctx.current.flush() - self.assert_sql_count(testbase.db, go, 0) - f1.value = unicode('someothervalue') - self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ + Session.commit() + self.assert_sql_count(testing.db, go, 0) + f1.val = unicode('someothervalue') + self.assert_sql(testing.db, lambda: Session.commit(), [ ( - "UPDATE mutabletest SET value=:value WHERE mutabletest.id = :mutabletest_id", - {'mutabletest_id': f1.id, 'value': u'someothervalue'} + "UPDATE mutabletest SET val=:val WHERE mutabletest.id = :mutabletest_id", + {'mutabletest_id': f1.id, 'val': u'someothervalue'} ), ]) - f1.value = unicode('hi') + f1.val = unicode('hi') f1.data.x = 9 - self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ + self.assert_sql(testing.db, lambda: Session.commit(), [ ( - "UPDATE mutabletest SET data=:data, value=:value WHERE mutabletest.id = :mutabletest_id", - {'mutabletest_id': f1.id, 'value': u'hi', 'data':f1.data} + "UPDATE mutabletest SET data=:data, val=:val WHERE mutabletest.id = :mutabletest_id", + {'mutabletest_id': f1.id, 'val': u'hi', 'data':f1.data} ), ]) - - - def testnocomparison(self): + + def test_nocomparison(self): """test that types marked as MutableType get changes detected on them when the type has no __eq__ method""" class Foo(object):pass mapper(Foo, table) f1 = Foo() f1.data = pickleable.BarWithoutCompare(4,5) - ctx.current.flush() - + Session.commit() + def go(): - ctx.current.flush() - self.assert_sql_count(testbase.db, go, 0) - - ctx.current.clear() + Session.commit() + self.assert_sql_count(testing.db, go, 0) + + Session.close() - f2 = ctx.current.query(Foo).get_by(id=f1.id) + f2 = Session.query(Foo).filter_by(id=f1.id).one() def go(): - ctx.current.flush() - self.assert_sql_count(testbase.db, go, 0) + Session.commit() 
+ self.assert_sql_count(testing.db, go, 0) f2.data.y = 19 def go(): - ctx.current.flush() - self.assert_sql_count(testbase.db, go, 1) - - ctx.current.clear() - f3 = ctx.current.query(Foo).get_by(id=f1.id) + Session.commit() + self.assert_sql_count(testing.db, go, 1) + + Session.close() + f3 = Session.query(Foo).filter_by(id=f1.id).one() print f2.data, f3.data assert (f3.data.x, f3.data.y) == (4,19) def go(): - ctx.current.flush() - self.assert_sql_count(testbase.db, go, 0) - - def testunicode(self): + Session.commit() + self.assert_sql_count(testing.db, go, 0) + + def test_unicode(self): """test that two equivalent unicode values dont get flagged as changed. - + apparently two equal unicode objects dont compare via "is" in all cases, so this tests the compare_values() call on types.String and its usage via types.Unicode.""" class Foo(object):pass mapper(Foo, table) f1 = Foo() - f1.value = u'hi' - ctx.current.flush() - ctx.current.clear() - f1 = ctx.current.get(Foo, f1.id) - f1.value = u'hi' + f1.val = u'hi' + Session.commit() + Session.close() + f1 = Session.get(Foo, f1.id) + f1.val = u'hi' def go(): - ctx.current.flush() - self.assert_sql_count(testbase.db, go, 0) - - -class PKTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - global table, table2, table3, metadata - metadata = MetaData(testbase.db) + Session.commit() + self.assert_sql_count(testing.db, go, 0) + +class MutableTypesTest2(ORMTest): + def define_tables(self, metadata): + global table + import operator + table = Table('mutabletest', metadata, + Column('id', Integer, Sequence('mutableidseq', optional=True), primary_key=True), + Column('data', PickleType(comparator=operator.eq)), + ) + + def test_dicts(self): + """dictionaries dont pickle the same way twice, sigh.""" + + class Foo(object):pass + mapper(Foo, table) + f1 = Foo() + f1.data = [{'personne': {'nom': u'Smith', 'pers_id': 1, 'prenom': u'john', 'civilite': u'Mr', \ + 'int_3': False, 'int_2': False, 'int_1': u'23', 'VenSoir': True, 'str_1': u'Test', \ + 'SamMidi': False, 'str_2': u'chien', 'DimMidi': False, 'SamSoir': True, 'SamAcc': False}}] + + Session.commit() + def go(): + Session.commit() + self.assert_sql_count(testing.db, go, 0) + + f1.data = [{'personne': {'nom': u'Smith', 'pers_id': 1, 'prenom': u'john', 'civilite': u'Mr', \ + 'int_3': False, 'int_2': False, 'int_1': u'23', 'VenSoir': True, 'str_1': u'Test', \ + 'SamMidi': False, 'str_2': u'chien', 'DimMidi': False, 'SamSoir': True, 'SamAcc': False}}] + + def go(): + Session.commit() + self.assert_sql_count(testing.db, go, 0) + + f1.data[0]['personne']['VenSoir']= False + def go(): + Session.commit() + self.assert_sql_count(testing.db, go, 1) + + Session.clear() + f = Session.query(Foo).get(f1.id) + assert f.data == [{'personne': {'nom': u'Smith', 'pers_id': 1, 'prenom': u'john', 'civilite': u'Mr', \ + 'int_3': False, 'int_2': False, 'int_1': u'23', 'VenSoir': False, 'str_1': u'Test', \ + 'SamMidi': False, 'str_2': u'chien', 'DimMidi': False, 'SamSoir': True, 'SamAcc': False}}] + +class PKTest(ORMTest): + def define_tables(self, metadata): + global table, table2, table3 + table = Table( - 'multipk', metadata, + 'multipk', metadata, Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True), Column('multi_rev', Integer, primary_key=True), Column('name', String(50), nullable=False), Column('value', String(100)) ) - + table2 = Table('multipk2', metadata, Column('pk_col_1', String(30), primary_key=True), Column('pk_col_2', String(30), primary_key=True), @@ 
-349,16 +425,11 @@ class PKTest(UnitOfWorkTest): Column('date_assigned', Date, key='assigned', primary_key=True), Column('data', String(30), ) ) - metadata.create_all() - def tearDownAll(self): - metadata.drop_all() - UnitOfWorkTest.tearDownAll(self) - - # not support on sqlite since sqlite's auto-pk generation only works with - # single column primary keys - @testing.unsupported('sqlite') - def testprimarykey(self): + # not supported on sqlite since sqlite's auto-pk generation only works with + # single column primary keys + @testing.fails_on('sqlite') + def test_primarykey(self): class Entry(object): pass Entry.mapper = mapper(Entry, table) @@ -366,13 +437,13 @@ class PKTest(UnitOfWorkTest): e.name = 'entry1' e.value = 'this is entry 1' e.multi_rev = 2 - ctx.current.flush() - ctx.current.clear() + Session.commit() + Session.close() e2 = Query(Entry).get((e.multi_id, 2)) self.assert_(e is not e2 and e._instance_key == e2._instance_key) - + # this one works with sqlite since we are manually setting up pk values - def testmanualpk(self): + def test_manualpk(self): class Entry(object): pass Entry.mapper = mapper(Entry, table2) @@ -380,9 +451,9 @@ class PKTest(UnitOfWorkTest): e.pk_col_1 = 'pk1' e.pk_col_2 = 'pk1_related' e.data = 'im the data' - ctx.current.flush() - - def testkeypks(self): + Session.commit() + + def test_keypks(self): import datetime class Entity(object): pass @@ -392,59 +463,38 @@ class PKTest(UnitOfWorkTest): e.secondary = 'pk2' e.assigned = datetime.date.today() e.data = 'some more data' - ctx.current.flush() + Session.commit() - def testpksimmutable(self): - class Entry(object): - pass - mapper(Entry, table) - e = Entry() - e.multi_id=5 - e.multi_rev=5 - e.name='somename' - ctx.current.flush() - e.multi_rev=6 - e.name = 'someothername' - try: - ctx.current.flush() - assert False - except exceptions.FlushError, fe: - assert str(fe) == "Can't change the identity of instance Entry@%s in session (existing identity: (%s, (5, 5), None); new identity: (%s, (5, 6), None))" % (hex(id(e)), repr(e.__class__), repr(e.__class__)) - - -class ForeignPKTest(UnitOfWorkTest): +class ForeignPKTest(ORMTest): """tests mapper detection of the relationship direction when parent/child tables are joined on their primary keys""" - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - global metadata, people, peoplesites - metadata = MetaData(testbase.db) + + def define_tables(self, metadata): + global people, peoplesites + people = Table("people", metadata, Column('person', String(10), primary_key=True), Column('firstname', String(10)), Column('lastname', String(10)), ) - + peoplesites = Table("peoplesites", metadata, - Column('person', String(10), ForeignKey("people.person"), + Column('person', String(10), ForeignKey("people.person"), primary_key=True), Column('site', String(10)), ) - metadata.create_all() - def tearDownAll(self): - metadata.drop_all() - UnitOfWorkTest.tearDownAll(self) - def testbasic(self): + + def test_basic(self): class PersonSite(object):pass class Person(object):pass m1 = mapper(PersonSite, peoplesites) m2 = mapper(Person, people, properties = { - 'sites' : relation(PersonSite), + 'sites' : relation(PersonSite), }, ) - + compile_mappers() assert list(m2.get_property('sites').foreign_keys) == [peoplesites.c.person] p = Person() p.person = 'im the key' @@ -452,14 +502,71 @@ class ForeignPKTest(UnitOfWorkTest): ps = PersonSite() ps.site = 'asdf' p.sites.append(ps) - ctx.current.flush() + Session.commit() assert people.count(people.c.person=='im the key').scalar() == 
peoplesites.count(peoplesites.c.person=='im the key').scalar() == 1 -class PassiveDeletesTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - global metadata, mytable,myothertable - metadata = MetaData(testbase.db) +class ClauseAttributesTest(ORMTest): + def define_tables(self, metadata): + global users_table + users_table = Table('users', metadata, + Column('id', Integer, Sequence('users_id_seq', optional=True), primary_key=True), + Column('name', String(30)), + Column('counter', Integer, default=1)) + + def test_update(self): + class User(object): + pass + mapper(User, users_table) + u = User(name='test') + sess = Session() + sess.save(u) + sess.flush() + assert u.counter == 1 + u.counter = User.counter + 1 + sess.flush() + + def go(): + assert (u.counter == 2) is True # ensure its not a ClauseElement + self.assert_sql_count(testing.db, go, 1) + + def test_multi_update(self): + class User(object): + pass + mapper(User, users_table) + u = User(name='test') + sess = Session() + sess.save(u) + sess.flush() + assert u.counter == 1 + u.name = 'test2' + u.counter = User.counter + 1 + sess.flush() + def go(): + assert u.name == 'test2' + assert (u.counter == 2) is True + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + u = sess.query(User).get(u.id) + assert u.name == 'test2' + assert u.counter == 2 + + @testing.unsupported('mssql') + def test_insert(self): + class User(object): + pass + mapper(User, users_table) + u = User(name='test', counter=select([5])) + sess = Session() + sess.save(u) + sess.flush() + assert (u.counter == 5) is True + + +class PassiveDeletesTest(ORMTest): + def define_tables(self, metadata): + global mytable,myothertable + mytable = Table('mytable', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -474,130 +581,275 @@ class PassiveDeletesTest(UnitOfWorkTest): test_needs_fk=True, ) - metadata.create_all() - def tearDownAll(self): - metadata.drop_all() - UnitOfWorkTest.tearDownAll(self) - @testing.unsupported('sqlite') - def testbasic(self): + def test_basic(self): class MyClass(object): pass class MyOtherClass(object): pass - + mapper(MyOtherClass, myothertable) mapper(MyClass, mytable, properties={ 'children':relation(MyOtherClass, passive_deletes=True, cascade="all") }) - sess = ctx.current + sess = Session mc = MyClass() mc.children.append(MyOtherClass()) mc.children.append(MyOtherClass()) mc.children.append(MyOtherClass()) mc.children.append(MyOtherClass()) sess.save(mc) - sess.flush() - sess.clear() + sess.commit() + sess.close() assert myothertable.count().scalar() == 4 mc = sess.query(MyClass).get(mc.id) sess.delete(mc) - sess.flush() + sess.commit() assert mytable.count().scalar() == 0 assert myothertable.count().scalar() == 0 - - -class DefaultTest(UnitOfWorkTest): +class ExtraPassiveDeletesTest(ORMTest): + def define_tables(self, metadata): + global mytable,myothertable + + mytable = Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + test_needs_fk=True, + ) + + myothertable = Table('myothertable', metadata, + Column('id', Integer, primary_key=True), + Column('parent_id', Integer), + Column('data', String(30)), + ForeignKeyConstraint(['parent_id'],['mytable.id']), # no CASCADE, the same as ON DELETE RESTRICT + test_needs_fk=True, + ) + + def test_assertions(self): + class MyClass(object): + pass + class MyOtherClass(object): + pass + + mapper(MyOtherClass, myothertable) + + try: + mapper(MyClass, mytable, properties={ + 
'children':relation(MyOtherClass, passive_deletes='all', cascade="all") + }) + assert False + except exceptions.ArgumentError, e: + assert str(e) == "Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade" + + @testing.unsupported('sqlite') + def test_extra_passive(self): + class MyClass(object): + pass + class MyOtherClass(object): + pass + + mapper(MyOtherClass, myothertable) + + mapper(MyClass, mytable, properties={ + 'children':relation(MyOtherClass, passive_deletes='all', cascade="save-update") + }) + + sess = Session + mc = MyClass() + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + sess.save(mc) + sess.commit() + + assert myothertable.count().scalar() == 4 + mc = sess.query(MyClass).get(mc.id) + sess.delete(mc) + self.assertRaises(exceptions.DBAPIError, sess.commit) + + @testing.unsupported('sqlite') + def test_extra_passive_2(self): + class MyClass(object): + pass + class MyOtherClass(object): + pass + + mapper(MyOtherClass, myothertable) + + mapper(MyClass, mytable, properties={ + 'children':relation(MyOtherClass, passive_deletes='all', cascade="save-update") + }) + + sess = Session + mc = MyClass() + mc.children.append(MyOtherClass()) + sess.save(mc) + sess.commit() + + assert myothertable.count().scalar() == 1 + mc = sess.query(MyClass).get(mc.id) + sess.delete(mc) + mc.children[0].data = 'some new data' + self.assertRaises(exceptions.DBAPIError, sess.commit) + + +class DefaultTest(ORMTest): """tests that when saving objects whose table contains DefaultGenerators, either python-side, preexec or database-side, - the newly saved instances receive all the default values either through a post-fetch or getting the pre-exec'ed + the newly saved instances receive all the default values either through a post-fetch or getting the pre-exec'ed defaults back from the engine.""" - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - db = testbase.db - use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite') + + def define_tables(self, metadata): + db = testing.db + use_string_defaults = testing.against('postgres', 'oracle', 'sqlite') + global hohoval, althohoval if use_string_defaults: hohotype = String(30) - self.hohoval = "im hoho" - self.althohoval = "im different hoho" + hohoval = "im hoho" + althohoval = "im different hoho" else: hohotype = Integer - self.hohoval = 9 - self.althohoval = 15 - global default_table - metadata = MetaData(db) + hohoval = 9 + althohoval = 15 + + global default_table, secondary_table default_table = Table('default_test', metadata, - Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True), - Column('hoho', hohotype, PassiveDefault(str(self.hohoval))), - Column('counter', Integer, PassiveDefault("7")), - Column('foober', String(30), default="im foober", onupdate="im the update") + Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True), + Column('hoho', hohotype, PassiveDefault(str(hohoval))), + Column('counter', Integer, default=func.length("1234567")), + Column('foober', String(30), default="im foober", onupdate="im the update"), ) - default_table.create() - def tearDownAll(self): - default_table.drop() - UnitOfWorkTest.tearDownAll(self) - def testinsert(self): + + secondary_table = Table('secondary_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + if 
testing.against('postgres', 'oracle'): + default_table.append_column(Column('secondary_id', Integer, Sequence('sec_id_seq'), unique=True)) + secondary_table.append_column(Column('fk_val', Integer, ForeignKey('default_test.secondary_id'))) + else: + secondary_table.append_column(Column('hoho', hohotype, ForeignKey('default_test.hoho'))) + + def test_insert(self): class Hoho(object):pass - assign_mapper(Hoho, default_table) - h1 = Hoho(hoho=self.althohoval) + mapper(Hoho, default_table) + + h1 = Hoho(hoho=althohoval) h2 = Hoho(counter=12) - h3 = Hoho(hoho=self.althohoval, counter=12) + h3 = Hoho(hoho=althohoval, counter=12) h4 = Hoho() h5 = Hoho(foober='im the new foober') - ctx.current.flush() - self.assert_(h1.hoho==self.althohoval) - self.assert_(h3.hoho==self.althohoval) - self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval) - self.assert_(h3.counter == h2.counter == 12) - self.assert_(h1.counter == h4.counter==h5.counter==7) - self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') - self.assert_(h5.foober=='im the new foober') - ctx.current.clear() - l = Query(Hoho).select() + Session.commit() + + self.assert_(h1.hoho==althohoval) + self.assert_(h3.hoho==althohoval) + + def go(): + # test deferred load of attribues, one select per instance + self.assert_(h2.hoho==h4.hoho==h5.hoho==hohoval) + self.assert_sql_count(testing.db, go, 3) + + def go(): + self.assert_(h1.counter == h4.counter==h5.counter==7) + self.assert_sql_count(testing.db, go, 1) + + def go(): + self.assert_(h3.counter == h2.counter == 12) + self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') + self.assert_(h5.foober=='im the new foober') + self.assert_sql_count(testing.db, go, 0) + + Session.close() + + l = Hoho.query.all() + (h1, h2, h3, h4, h5) = l - self.assert_(h1.hoho==self.althohoval) - self.assert_(h3.hoho==self.althohoval) - self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval) + + self.assert_(h1.hoho==althohoval) + self.assert_(h3.hoho==althohoval) + self.assert_(h2.hoho==h4.hoho==h5.hoho==hohoval) self.assert_(h3.counter == h2.counter == 12) self.assert_(h1.counter == h4.counter==h5.counter==7) self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') self.assert_(h5.foober=='im the new foober') - - def testinsertnopostfetch(self): + + def test_eager_defaults(self): + class Hoho(object):pass + mapper(Hoho, default_table, eager_defaults=True) + h1 = Hoho() + Session.commit() + + def go(): + self.assert_(h1.hoho==hohoval) + self.assert_sql_count(testing.db, go, 0) + + def test_insert_nopostfetch(self): # populates the PassiveDefaults explicitly so there is no "post-update" class Hoho(object):pass - assign_mapper(Hoho, default_table) + mapper(Hoho, default_table) + h1 = Hoho(hoho="15", counter="15") - ctx.current.flush() - self.assert_(h1.hoho=="15") - self.assert_(h1.counter=="15") - self.assert_(h1.foober=="im foober") - - def testupdate(self): + + Session.commit() + def go(): + self.assert_(h1.hoho=="15") + self.assert_(h1.counter=="15") + self.assert_(h1.foober=="im foober") + self.assert_sql_count(testing.db, go, 0) + + def test_update(self): class Hoho(object):pass - assign_mapper(Hoho, default_table) + mapper(Hoho, default_table) h1 = Hoho() - ctx.current.flush() - self.assert_(h1.foober == 'im foober') + Session.commit() + self.assertEquals(h1.foober, 'im foober') h1.counter = 19 - ctx.current.flush() - self.assert_(h1.foober == 'im the update') - -class OneToManyTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - tables.create() - def 
tearDownAll(self): - tables.drop() - UnitOfWorkTest.tearDownAll(self) - def tearDown(self): - tables.delete() - UnitOfWorkTest.tearDown(self) + Session.commit() + self.assertEquals(h1.foober, 'im the update') + + def test_used_in_relation(self): + """test that a server-side generated default can be used as the target of a foreign key""" + + class Hoho(fixtures.Base): + pass + class Secondary(fixtures.Base): + pass + mapper(Hoho, default_table, properties={ + 'secondaries':relation(Secondary) + }, save_on_init=False) + + mapper(Secondary, secondary_table, save_on_init=False) + h1 = Hoho() + s1 = Secondary(data='s1') + h1.secondaries.append(s1) + Session.save(h1) + Session.commit() + Session.clear() + + self.assertEquals(Session.query(Hoho).get(h1.id), Hoho(hoho=hohoval, secondaries=[Secondary(data='s1')])) + + h1 = Session.query(Hoho).get(h1.id) + h1.secondaries.append(Secondary(data='s2')) + Session.commit() + Session.clear() + + self.assertEquals(Session.query(Hoho).get(h1.id), + Hoho(hoho=hohoval, secondaries=[Secondary(data='s1'), Secondary(data='s2')]) + ) + + +class OneToManyTest(ORMTest): + metadata = tables.metadata + + def define_tables(self, metadata): + pass - def testonetomany_1(self): + def test_onetomany_1(self): """test basic save of one to many.""" m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True) @@ -612,11 +864,11 @@ class OneToManyTest(UnitOfWorkTest): a2.email_address = 'lala@test.org' u.addresses.append(a2) print repr(u.addresses) - ctx.current.flush() + Session.commit() - usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall() + usertable = users.select(users.c.user_id.in_([u.user_id])).execute().fetchall() self.assertEqual(usertable[0].values(), [u.user_id, 'one2manytester']) - addresstable = addresses.select(addresses.c.address_id.in_(a.address_id, a2.address_id), order_by=[addresses.c.email_address]).execute().fetchall() + addresstable = addresses.select(addresses.c.address_id.in_([a.address_id, a2.address_id]), order_by=[addresses.c.email_address]).execute().fetchall() self.assertEqual(addresstable[0].values(), [a2.address_id, u.user_id, 'lala@test.org']) self.assertEqual(addresstable[1].values(), [a.address_id, u.user_id, 'one2many@test.org']) @@ -625,13 +877,13 @@ class OneToManyTest(UnitOfWorkTest): a2.email_address = 'somethingnew@foo.com' - ctx.current.flush() + Session.commit() addresstable = addresses.select(addresses.c.address_id == addressid).execute().fetchall() self.assertEqual(addresstable[0].values(), [addressid, userid, 'somethingnew@foo.com']) self.assert_(u.user_id == userid and a2.address_id == addressid) - def testonetomany_2(self): + def test_onetomany_2(self): """digs deeper into modifying the child items of an object to insure the correct updates take place""" m = mapper(User, users, properties = dict( @@ -653,7 +905,7 @@ class OneToManyTest(UnitOfWorkTest): a3 = Address() a3.email_address = 'emailaddress3' - ctx.current.flush() + Session.commit() # modify user2 directly, append an address to user1. 
# upon commit, user2 should be updated, user1 should not @@ -661,7 +913,7 @@ class OneToManyTest(UnitOfWorkTest): u2.user_name = 'user2modified' u1.addresses.append(a3) del u1.addresses[0] - self.assert_sql(testbase.db, lambda: ctx.current.flush(), + self.assert_sql(testing.db, lambda: Session.commit(), [ ( "UPDATE users SET user_name=:user_name WHERE users.user_id = :users_user_id", @@ -676,7 +928,7 @@ class OneToManyTest(UnitOfWorkTest): ), ]) - def testchildmove(self): + def test_childmove(self): """tests moving a child from one parent to the other, then deleting the first parent, properly updates the child with the new parent. this tests the 'trackparent' option in the attributes module.""" m = mapper(User, users, properties = dict( @@ -689,16 +941,16 @@ class OneToManyTest(UnitOfWorkTest): a = Address() a.email_address = 'address1' u1.addresses.append(a) - ctx.current.flush() + Session.commit() del u1.addresses[0] u2.addresses.append(a) - ctx.current.delete(u1) - ctx.current.flush() - ctx.current.clear() - u2 = ctx.current.get(User, u2.user_id) + Session.delete(u1) + Session.commit() + Session.close() + u2 = Session.get(User, u2.user_id) assert len(u2.addresses) == 1 - def testchildmove_2(self): + def test_childmove_2(self): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True) )) @@ -709,15 +961,15 @@ class OneToManyTest(UnitOfWorkTest): a = Address() a.email_address = 'address1' u1.addresses.append(a) - ctx.current.flush() + Session.commit() del u1.addresses[0] u2.addresses.append(a) - ctx.current.flush() - ctx.current.clear() - u2 = ctx.current.get(User, u2.user_id) + Session.commit() + Session.close() + u2 = Session.get(User, u2.user_id) assert len(u2.addresses) == 1 - def testo2mdeleteparent(self): + def test_o2m_delete_parent(self): m = mapper(User, users, properties = dict( address = relation(mapper(Address, addresses), lazy = True, uselist = False, private = False) )) @@ -726,12 +978,12 @@ class OneToManyTest(UnitOfWorkTest): u.user_name = 'one2onetester' u.address = a u.address.email_address = 'myonlyaddress@foo.com' - ctx.current.flush() - ctx.current.delete(u) - ctx.current.flush() - self.assert_(a.address_id is not None and a.user_id is None and not ctx.current.identity_map.has_key(u._instance_key) and ctx.current.identity_map.has_key(a._instance_key)) + Session.commit() + Session.delete(u) + Session.commit() + self.assert_(a.address_id is not None and a.user_id is None and u._instance_key not in Session.identity_map and a._instance_key in Session.identity_map) - def testonetoone(self): + def test_onetoone(self): m = mapper(User, users, properties = dict( address = relation(mapper(Address, addresses), lazy = True, uselist = False) )) @@ -739,13 +991,13 @@ class OneToManyTest(UnitOfWorkTest): u.user_name = 'one2onetester' u.address = Address() u.address.email_address = 'myonlyaddress@foo.com' - ctx.current.flush() + Session.commit() u.user_name = 'imnew' - ctx.current.flush() + Session.commit() u.address.email_address = 'imnew@foo.com' - ctx.current.flush() + Session.commit() - def testbidirectional(self): + def test_bidirectional(self): m1 = mapper(User, users) m2 = mapper(Address, addresses, properties = dict( @@ -759,7 +1011,7 @@ class OneToManyTest(UnitOfWorkTest): a = Address() a.email_address = 'testaddress' a.user = u - ctx.current.flush() + Session.commit() print repr(u.addresses) x = False try: @@ -771,18 +1023,18 @@ class OneToManyTest(UnitOfWorkTest): if x: self.assert_(False, "User addresses element should be 
scalar based") - ctx.current.delete(u) - ctx.current.flush() + Session.delete(u) + Session.commit() - def testdoublerelation(self): + def test_doublerelation(self): m2 = mapper(Address, addresses) m = mapper(User, users, properties={ 'boston_addresses' : relation(m2, primaryjoin= - and_(users.c.user_id==Address.c.user_id, - Address.c.email_address.like('%boston%'))), + and_(users.c.user_id==addresses.c.user_id, + addresses.c.email_address.like('%boston%'))), 'newyork_addresses' : relation(m2, primaryjoin= - and_(users.c.user_id==Address.c.user_id, - Address.c.email_address.like('%newyork%'))), + and_(users.c.user_id==addresses.c.user_id, + addresses.c.email_address.like('%newyork%'))), }) u = User() a = Address() @@ -792,18 +1044,15 @@ class OneToManyTest(UnitOfWorkTest): u.boston_addresses.append(a) u.newyork_addresses.append(b) - ctx.current.flush() + Session.commit() -class SaveTest(UnitOfWorkTest): +class SaveTest(ORMTest): + metadata = tables.metadata + def define_tables(self, metadata): + pass - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - tables.create() - def tearDownAll(self): - tables.drop() - UnitOfWorkTest.tearDownAll(self) - def setUp(self): + super(SaveTest, self).setUp() keywords.insert().execute( dict(name='blue'), dict(name='red'), @@ -814,11 +1063,7 @@ class SaveTest(UnitOfWorkTest): dict(name='square') ) - def tearDown(self): - tables.delete() - UnitOfWorkTest.tearDown(self) - - def testbasic(self): + def test_basic(self): # save two users u = User() u.user_name = 'savetester' @@ -826,59 +1071,82 @@ class SaveTest(UnitOfWorkTest): u2 = User() u2.user_name = 'savetester2' - ctx.current.save(u) - - ctx.current.flush([u]) - ctx.current.flush() + Session.save(u) + + Session.flush([u]) + Session.commit() # assert the first one retreives the same from the identity map - nu = ctx.current.get(m, u.user_id) + nu = Session.get(m, u.user_id) print "U: " + repr(u) + "NU: " + repr(nu) self.assert_(u is nu) - + # clear out the identity map, so next get forces a SELECT - ctx.current.clear() + Session.close() # check it again, identity should be different but ids the same - nu = ctx.current.get(m, u.user_id) + nu = Session.get(m, u.user_id) self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester') + Session.close() # change first users name and save - ctx.current.update(u) + Session.update(u) u.user_name = 'modifiedname' - assert u in ctx.current.dirty - ctx.current.flush() + assert u in Session.dirty + Session.commit() # select both - #ctx.current.clear() - userlist = Query(m).select(users.c.user_id.in_(u.user_id, u2.user_id), order_by=[users.c.user_name]) + #Session.close() + userlist = User.query.filter(users.c.user_id.in_([u.user_id, u2.user_id])).order_by([users.c.user_name]).all() print repr(u.user_id), repr(userlist[0].user_id), repr(userlist[0].user_name) self.assert_(u.user_id == userlist[0].user_id and userlist[0].user_name == 'modifiedname') self.assert_(u2.user_id == userlist[1].user_id and userlist[1].user_name == 'savetester2') - def testlazyattrcommit(self): + def test_synonym(self): + class User(object): + def _get_name(self): + return "User:" + self.user_name + def _set_name(self, name): + self.user_name = name + ":User" + name = property(_get_name, _set_name) + + mapper(User, users, properties={ + 'name':synonym('user_name') + }) + + u = User() + u.name = "some name" + assert u.name == 'User:some name:User' + Session.save(u) + Session.flush() + Session.clear() + u = Session.query(User).first() + assert u.name == 'User:some 
name:User' + + def test_lazyattr_commit(self): """tests that when a lazy-loaded list is unloaded, and a commit occurs, that the 'passive' call on that list does not blow away its value""" + m1 = mapper(User, users, properties = { 'addresses': relation(mapper(Address, addresses)) }) - + u = User() u.addresses.append(Address()) u.addresses.append(Address()) u.addresses.append(Address()) u.addresses.append(Address()) - ctx.current.flush() - ctx.current.clear() - ulist = ctx.current.query(m1).select() + Session.commit() + Session.close() + ulist = Session.query(m1).all() u1 = ulist[0] u1.user_name = 'newname' - ctx.current.flush() + Session.commit() self.assert_(len(u1.addresses) == 4) - - def testinherits(self): + + def test_inherits(self): m1 = mapper(User, users) - + class AddressUser(User): """a user object that also has the users mailing address.""" pass @@ -888,96 +1156,127 @@ class SaveTest(UnitOfWorkTest): AddressUser, addresses, inherits=m1 ) - + au = AddressUser() - ctx.current.flush() - ctx.current.clear() - l = ctx.current.query(AddressUser).selectone() + Session.commit() + Session.close() + l = Session.query(AddressUser).one() self.assert_(l.user_id == au.user_id and l.address_id == au.address_id) - - def testdeferred(self): - """test that a deferred load within a flush() doesnt screw up the connection""" + + def test_deferred(self): + """test deferred column operations""" + mapper(User, users, properties={ 'user_name':deferred(users.c.user_name) }) + + # dont set deferred attribute, commit session u = User() u.user_id=42 - ctx.current.flush() - + Session.commit() + + # assert that changes get picked up + u.user_name = 'some name' + Session.commit() + assert list(Session.execute(users.select(), mapper=User)) == [(42, 'some name')] + Session.clear() + + # assert that a set operation doesn't trigger a load operation + u = Session.query(User).filter(User.user_name=='some name').one() + def go(): + u.user_name = 'some other name' + self.assert_sql_count(testing.db, go, 0) + Session.flush() + assert list(Session.execute(users.select(), mapper=User)) == [(42, 'some other name')] + + Session.clear() + + # test assigning None to an unloaded deferred also works + u = Session.query(User).filter(User.user_name=='some other name').one() + u.user_name = None + Session.flush() + assert list(Session.execute(users.select(), mapper=User)) == [(42, None)] + + + # why no support on oracle ? because oracle doesn't save + # "blank" strings; it saves a single space character. + @testing.unsupported('oracle') def test_dont_update_blanks(self): mapper(User, users) u = User() u.user_name = "" - ctx.current.flush() - ctx.current.clear() - u = ctx.current.query(User).get(u.user_id) + Session.commit() + Session.close() + u = Session.query(User).get(u.user_id) u.user_name = "" def go(): - ctx.current.flush() - self.assert_sql_count(testbase.db, go, 0) + Session.commit() + self.assert_sql_count(testing.db, go, 0) - def testmultitable(self): + def test_multitable(self): """tests a save of an object where each instance spans two tables. 
also tests redefinition of the keynames for the column properties.""" usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) - m = mapper(User, usersaddresses, + m = mapper(User, usersaddresses, properties = dict( - email = addresses.c.email_address, + email = addresses.c.email_address, foo_id = [users.c.user_id, addresses.c.user_id], ) ) - + u = User() u.user_name = 'multitester' u.email = 'multi@test.org' - ctx.current.flush() + Session.commit() id = m.primary_key_from_instance(u) - ctx.current.clear() - - u = ctx.current.get(User, id) + Session.close() + + u = Session.get(User, id) assert u.user_name == 'multitester' - - usertable = users.select(users.c.user_id.in_(u.foo_id)).execute().fetchall() + + usertable = users.select(users.c.user_id.in_([u.foo_id])).execute().fetchall() self.assertEqual(usertable[0].values(), [u.foo_id, 'multitester']) - addresstable = addresses.select(addresses.c.address_id.in_(u.address_id)).execute().fetchall() + addresstable = addresses.select(addresses.c.address_id.in_([u.address_id])).execute().fetchall() self.assertEqual(addresstable[0].values(), [u.address_id, u.foo_id, 'multi@test.org']) u.email = 'lala@hey.com' u.user_name = 'imnew' - ctx.current.flush() + Session.commit() - usertable = users.select(users.c.user_id.in_(u.foo_id)).execute().fetchall() + usertable = users.select(users.c.user_id.in_([u.foo_id])).execute().fetchall() self.assertEqual(usertable[0].values(), [u.foo_id, 'imnew']) - addresstable = addresses.select(addresses.c.address_id.in_(u.address_id)).execute().fetchall() + addresstable = addresses.select(addresses.c.address_id.in_([u.address_id])).execute().fetchall() self.assertEqual(addresstable[0].values(), [u.address_id, u.foo_id, 'lala@hey.com']) - ctx.current.clear() - u = ctx.current.get(User, id) + Session.close() + u = Session.get(User, id) assert u.user_name == 'imnew' - - def testhistoryget(self): + + def test_history_get(self): """tests that the history properly lazy-fetches data when it wasnt otherwise loaded""" mapper(User, users, properties={ 'addresses':relation(Address, cascade="all, delete-orphan") }) mapper(Address, addresses) - + u = User() u.addresses.append(Address()) u.addresses.append(Address()) - ctx.current.flush() - ctx.current.clear() - u = ctx.current.query(User).get(u.user_id) - ctx.current.delete(u) - ctx.current.flush() + Session.commit() + Session.close() + u = Session.query(User).get(u.user_id) + Session.delete(u) + Session.commit() assert users.count().scalar() == 0 assert addresses.count().scalar() == 0 - - - - def testbatchmode(self): + + + + def test_batchmode(self): + """test the 'batch=False' flag on mapper()""" + class TestExtension(MapperExtension): def before_insert(self, mapper, connection, instance): self.current_instance = instance @@ -988,34 +1287,29 @@ class SaveTest(UnitOfWorkTest): u1.username = 'user1' u2 = User() u2.username = 'user2' - ctx.current.flush() - + Session.commit() + clear_mappers() - + m = mapper(User, users, extension=TestExtension()) u1 = User() u1.username = 'user1' u2 = User() u2.username = 'user2' try: - ctx.current.flush() + Session.commit() assert False except AssertionError: assert True - - -class ManyToOneTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - tables.create() - def tearDownAll(self): - tables.drop() - UnitOfWorkTest.tearDownAll(self) - def tearDown(self): - tables.delete() - UnitOfWorkTest.tearDown(self) - - def testm2oonetoone(self): + + +class ManyToOneTest(ORMTest): + metadata = tables.metadata + 
+ def define_tables(self, metadata): + pass + + def test_m2o_onetoone(self): # TODO: put assertion in here !!! m = mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy = True, uselist = False) @@ -1034,12 +1328,12 @@ class ManyToOneTest(UnitOfWorkTest): a.user = User() a.user.user_name = elem['user_name'] objects.append(a) - - ctx.current.flush() + + Session.commit() objects[2].email_address = 'imnew@foo.bar' objects[3].user = User() objects[3].user.user_name = 'imnewlyadded' - self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ + self.assert_sql(testing.db, lambda: Session.commit(), [ ( "INSERT INTO users (user_name) VALUES (:user_name)", {'user_name': 'imnewlyadded'} @@ -1048,11 +1342,11 @@ class ManyToOneTest(UnitOfWorkTest): "UPDATE email_addresses SET email_address=:email_address WHERE email_addresses.address_id = :email_addresses_address_id": lambda ctx: {'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id} , - + "UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id": lambda ctx: {'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id} }, - + ], with_sequences=[ ( @@ -1063,17 +1357,17 @@ class ManyToOneTest(UnitOfWorkTest): "UPDATE email_addresses SET email_address=:email_address WHERE email_addresses.address_id = :email_addresses_address_id": lambda ctx: {'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id} , - + "UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id": lambda ctx: {'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id} }, - + ]) l = sql.select([users, addresses], sql.and_(users.c.user_id==addresses.c.user_id, addresses.c.address_id==a.address_id)).execute() assert l.fetchone().values() == [a.user.user_id, 'asdf8d', a.address_id, a.user_id, 'theater@foo.com'] - def testmanytoone_1(self): + def test_manytoone_1(self): m = mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy = True) )) @@ -1081,22 +1375,22 @@ class ManyToOneTest(UnitOfWorkTest): a1.email_address = 'emailaddress1' u1 = User() u1.user_name='user1' - + a1.user = u1 - ctx.current.flush() - ctx.current.clear() - a1 = ctx.current.query(Address).get(a1.address_id) - u1 = ctx.current.query(User).get(u1.user_id) + Session.commit() + Session.close() + a1 = Session.query(Address).get(a1.address_id) + u1 = Session.query(User).get(u1.user_id) assert a1.user is u1 a1.user = None - ctx.current.flush() - ctx.current.clear() - a1 = ctx.current.query(Address).get(a1.address_id) - u1 = ctx.current.query(User).get(u1.user_id) + Session.commit() + Session.close() + a1 = Session.query(Address).get(a1.address_id) + u1 = Session.query(User).get(u1.user_id) assert a1.user is None - def testmanytoone_2(self): + def test_manytoone_2(self): m = mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy = True) )) @@ -1108,23 +1402,23 @@ class ManyToOneTest(UnitOfWorkTest): u1.user_name='user1' a1.user = u1 - ctx.current.flush() - ctx.current.clear() - a1 = ctx.current.query(Address).get(a1.address_id) - a2 = ctx.current.query(Address).get(a2.address_id) - u1 = ctx.current.query(User).get(u1.user_id) + Session.commit() + Session.close() + a1 = Session.query(Address).get(a1.address_id) + a2 = Session.query(Address).get(a2.address_id) + u1 = Session.query(User).get(u1.user_id) assert 
a1.user is u1 a1.user = None a2.user = u1 - ctx.current.flush() - ctx.current.clear() - a1 = ctx.current.query(Address).get(a1.address_id) - a2 = ctx.current.query(Address).get(a2.address_id) - u1 = ctx.current.query(User).get(u1.user_id) + Session.commit() + Session.close() + a1 = Session.query(Address).get(a1.address_id) + a2 = Session.query(Address).get(a2.address_id) + u1 = Session.query(User).get(u1.user_id) assert a1.user is None assert a2.user is u1 - def testmanytoone_3(self): + def test_manytoone_3(self): m = mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy = True) )) @@ -1136,33 +1430,53 @@ class ManyToOneTest(UnitOfWorkTest): u2.user_name='user2' a1.user = u1 - ctx.current.flush() - ctx.current.clear() - a1 = ctx.current.query(Address).get(a1.address_id) - u1 = ctx.current.query(User).get(u1.user_id) - u2 = ctx.current.query(User).get(u2.user_id) + Session.commit() + Session.close() + a1 = Session.query(Address).get(a1.address_id) + u1 = Session.query(User).get(u1.user_id) + u2 = Session.query(User).get(u2.user_id) assert a1.user is u1 - + a1.user = u2 - ctx.current.flush() - ctx.current.clear() - a1 = ctx.current.query(Address).get(a1.address_id) - u1 = ctx.current.query(User).get(u1.user_id) - u2 = ctx.current.query(User).get(u2.user_id) + Session.commit() + Session.close() + a1 = Session.query(Address).get(a1.address_id) + u1 = Session.query(User).get(u1.user_id) + u2 = Session.query(User).get(u2.user_id) assert a1.user is u2 - -class ManyToManyTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - tables.create() - def tearDownAll(self): - tables.drop() - UnitOfWorkTest.tearDownAll(self) - def tearDown(self): - tables.delete() - UnitOfWorkTest.tearDown(self) - def testmanytomany(self): + def test_bidirectional_noload(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=None) + }) + mapper(Address, addresses) + + sess = Session() + + # try it on unsaved objects + u1 = User() + a1 = Address() + a1.user = u1 + sess.save(u1) + sess.flush() + sess.clear() + + a1 = sess.query(Address).get(a1.address_id) + + a1.user = None + sess.flush() + sess.clear() + assert sess.query(Address).get(a1.address_id).user is None + assert sess.query(User).get(u1.user_id).addresses == [] + + +class ManyToManyTest(ORMTest): + metadata = tables.metadata + + def define_tables(self, metadata): + pass + + def test_manytomany(self): items = orderitems keywordmapper = mapper(Keyword, keywords) @@ -1185,8 +1499,8 @@ class ManyToManyTest(UnitOfWorkTest): objects.append(item) item.item_name = elem['item_name'] item.keywords = [] - if len(elem['keywords'][1]): - klist = ctx.current.query(keywordmapper).select(keywords.c.name.in_(*[e['name'] for e in elem['keywords'][1]])) + if elem['keywords'][1]: + klist = Session.query(keywordmapper).filter(keywords.c.name.in_([e['name'] for e in elem['keywords'][1]])) else: klist = [] khash = {} @@ -1200,16 +1514,16 @@ class ManyToManyTest(UnitOfWorkTest): k.name = kname item.keywords.append(k) - ctx.current.flush() - - l = ctx.current.query(m).select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name]) + Session.commit() + + l = Session.query(m).filter(items.c.item_name.in_([e['item_name'] for e in data[1:]])).order_by(items.c.item_name).all() self.assert_result(l, *data) objects[4].item_name = 'item4updated' k = Keyword() k.name = 'yellow' objects[5].keywords.append(k) - self.assert_sql(testbase.db, lambda:ctx.current.flush(), [ + 
self.assert_sql(testing.db, lambda:Session.commit(), [ { "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id": {'item_name': 'item4updated', 'items_item_id': objects[4].item_id} @@ -1221,7 +1535,7 @@ class ManyToManyTest(UnitOfWorkTest): lambda ctx: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}] ) ], - + with_sequences = [ { "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id": @@ -1238,21 +1552,21 @@ class ManyToManyTest(UnitOfWorkTest): objects[2].keywords.append(k) dkid = objects[5].keywords[1].keyword_id del objects[5].keywords[1] - self.assert_sql(testbase.db, lambda:ctx.current.flush(), [ + self.assert_sql(testing.db, lambda:Session.commit(), [ ( "DELETE FROM itemkeywords WHERE itemkeywords.item_id = :item_id AND itemkeywords.keyword_id = :keyword_id", [{'item_id': objects[5].item_id, 'keyword_id': dkid}] ), - ( + ( "INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)", lambda ctx: [{'item_id': objects[2].item_id, 'keyword_id': k.keyword_id}] ) ]) - - ctx.current.delete(objects[3]) - ctx.current.flush() - def testmanytomanyremove(self): + Session.delete(objects[3]) + Session.commit() + + def test_manytomany_remove(self): """tests that setting a list-based attribute to '[]' properly affects the history and allows the many-to-many rows to be deleted""" keywordmapper = mapper(Keyword, keywords) @@ -1266,30 +1580,30 @@ class ManyToManyTest(UnitOfWorkTest): k2 = Keyword() i.keywords.append(k1) i.keywords.append(k2) - ctx.current.flush() - + Session.commit() + assert itemkeywords.count().scalar() == 2 i.keywords = [] - ctx.current.flush() + Session.commit() assert itemkeywords.count().scalar() == 0 - def testscalar(self): + def test_scalar(self): """test that dependency.py doesnt try to delete an m2m relation referencing None.""" - + mapper(Keyword, keywords) mapper(Item, orderitems, properties = dict( keyword = relation(Keyword, secondary=itemkeywords, uselist=False), )) - + i = Item() - ctx.current.flush() - ctx.current.delete(i) - ctx.current.flush() - - + Session.commit() + Session.delete(i) + Session.commit() + - def testmanytomanyupdate(self): + + def test_manytomany_update(self): """tests some history operations on a many to many""" class Keyword(object): def __init__(self, name): @@ -1298,7 +1612,7 @@ class ManyToManyTest(UnitOfWorkTest): return other.__class__ == Keyword and other.name == self.name def __repr__(self): return "Keyword(%s, %s)" % (getattr(self, 'keyword_id', 'None'), self.name) - + mapper(Keyword, keywords) mapper(Item, orderitems, properties = dict( keywords = relation(Keyword, secondary=itemkeywords, lazy=False, order_by=keywords.c.name), @@ -1310,19 +1624,20 @@ class ManyToManyTest(UnitOfWorkTest): item.keywords.append(k1) item.keywords.append(k2) item.keywords.append(k3) - ctx.current.flush() - + Session.commit() + item.keywords = [] item.keywords.append(k1) item.keywords.append(k2) - ctx.current.flush() - - ctx.current.clear() - item = ctx.current.query(Item).get(item.item_id) + Session.commit() + + Session.close() + item = Session.query(Item).get(item.item_id) print [k1, k2] + print item.keywords assert item.keywords == [k1, k2] - - def testassociation(self): + + def test_association(self): """basic test of an association object""" class IKAssociation(object): def __repr__(self): @@ -1342,29 +1657,29 @@ class ManyToManyTest(UnitOfWorkTest): )) data = [Item, - {'item_name': 'a_item1', 'keywords' : (IKAssociation, + {'item_name': 'a_item1', 'keywords' : (IKAssociation, [ 
{'keyword' : (Keyword, {'name': 'big'})}, - {'keyword' : (Keyword, {'name': 'green'})}, + {'keyword' : (Keyword, {'name': 'green'})}, {'keyword' : (Keyword, {'name': 'purple'})}, {'keyword' : (Keyword, {'name': 'round'})} ] - ) + ) }, - {'item_name': 'a_item2', 'keywords' : (IKAssociation, + {'item_name': 'a_item2', 'keywords' : (IKAssociation, [ {'keyword' : (Keyword, {'name': 'huge'})}, - {'keyword' : (Keyword, {'name': 'violet'})}, + {'keyword' : (Keyword, {'name': 'violet'})}, {'keyword' : (Keyword, {'name': 'yellow'})} ] - ) + ) }, - {'item_name': 'a_item3', 'keywords' : (IKAssociation, + {'item_name': 'a_item3', 'keywords' : (IKAssociation, [ {'keyword' : (Keyword, {'name': 'big'})}, - {'keyword' : (Keyword, {'name': 'blue'})}, + {'keyword' : (Keyword, {'name': 'blue'})}, ] - ) + ) } ] for elem in data[1:]: @@ -1373,7 +1688,7 @@ class ManyToManyTest(UnitOfWorkTest): item.keywords = [] for kname in [e['keyword'][1]['name'] for e in elem['keywords'][1]]: try: - k = Query(keywordmapper).select(keywords.c.name == kname)[0] + k = Keyword.query.filter(keywords.c.name == kname)[0] except IndexError: k = Keyword() k.name= kname @@ -1381,65 +1696,27 @@ class ManyToManyTest(UnitOfWorkTest): ik.keyword = k item.keywords.append(ik) - ctx.current.flush() - ctx.current.clear() - l = Query(m).select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name]) + Session.commit() + Session.close() + l = Item.query.filter(items.c.item_name.in_([e['item_name'] for e in data[1:]])).order_by(items.c.item_name).all() self.assert_result(l, *data) - def testm2mmultitable(self): - # many-to-many join on an association table - j = join(users, userkeywords, - users.c.user_id==userkeywords.c.user_id).join(keywords, - userkeywords.c.keyword_id==keywords.c.keyword_id) - print "PK", j.primary_key - # a class - class KeywordUser(object): - pass - - # map to it - the identity of a KeywordUser object will be - # (user_id, keyword_id) since those are the primary keys involved - m = mapper(KeywordUser, j, properties={ - 'user_id':[users.c.user_id, userkeywords.c.user_id], - 'keyword_id':[userkeywords.c.keyword_id, keywords.c.keyword_id], - 'keyword_name':keywords.c.name, - }, ) - - k = KeywordUser() - k.user_name = 'keyworduser' - k.keyword_name = 'a keyword' - ctx.current.flush() - - id = (k.user_id, k.keyword_id) - ctx.current.clear() - k = ctx.current.query(KeywordUser).get(id) - assert k.user_name == 'keyworduser' - assert k.keyword_name == 'a keyword' - - -class SaveTest2(UnitOfWorkTest): +class SaveTest2(ORMTest): - def setUp(self): - ctx.current.clear() - clear_mappers() - global meta, users, addresses - meta = MetaData(testbase.db) - users = Table('users', meta, + def define_tables(self, metadata): + global users, addresses + users = Table('users', metadata, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), Column('user_name', String(20)), ) - addresses = Table('email_addresses', meta, + addresses = Table('email_addresses', metadata, Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True), Column('rel_user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(20)), ) - meta.create_all() - def tearDown(self): - meta.drop_all() - UnitOfWorkTest.tearDown(self) - - def testbackwardsnonmatch(self): + def test_m2o_nonmatch(self): m = mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy = True, uselist = False) )) @@ -1454,7 +1731,7 @@ class 
SaveTest2(UnitOfWorkTest): a.user = User() a.user.user_name = elem['user_name'] objects.append(a) - self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ + self.assert_sql(testing.db, lambda: Session.commit(), [ ( "INSERT INTO users (user_name) VALUES (:user_name)", {'user_name': 'thesub'} @@ -1472,7 +1749,7 @@ class SaveTest2(UnitOfWorkTest): {'rel_user_id': 2, 'email_address': 'thesdf@asdf.com'} ) ], - + with_sequences = [ ( "INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)", @@ -1494,39 +1771,27 @@ class SaveTest2(UnitOfWorkTest): ) -class SaveTest3(UnitOfWorkTest): - def setUpAll(self): - global st3_metadata, t1, t2, t3 - - UnitOfWorkTest.setUpAll(self) +class SaveTest3(ORMTest): + def define_tables(self, metadata): + global t1, t2, t3 - st3_metadata = MetaData(testbase.db) - t1 = Table('items', st3_metadata, + t1 = Table('items', metadata, Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True), Column('item_name', VARCHAR(50)), ) - t3 = Table('keywords', st3_metadata, + t3 = Table('keywords', metadata, Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True), Column('name', VARCHAR(50)), ) - t2 = Table('assoc', st3_metadata, + t2 = Table('assoc', metadata, Column('item_id', INT, ForeignKey("items")), Column('keyword_id', INT, ForeignKey("keywords")), Column('foo', Boolean, default=True) ) - st3_metadata.create_all() - def tearDownAll(self): - st3_metadata.drop_all() - UnitOfWorkTest.tearDownAll(self) - - def setUp(self): - pass - def tearDown(self): - pass - def testmanytomanyxtracolremove(self): + def test_manytomany_xtracol_delete(self): """test that a many-to-many on a table that has an extra column can properly delete rows from the table without referencing the extra column""" mapper(Keyword, t3) @@ -1540,14 +1805,248 @@ class SaveTest3(UnitOfWorkTest): k2 = Keyword() i.keywords.append(k1) i.keywords.append(k2) - ctx.current.flush() + Session.commit() assert t2.count().scalar() == 2 i.keywords = [] print i.keywords - ctx.current.flush() + Session.commit() assert t2.count().scalar() == 0 +class BooleanColTest(ORMTest): + def define_tables(self, metadata): + global t + t =Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30)), + Column('value', Boolean)) + + def test_boolean(self): + # use the regular mapper + from sqlalchemy.orm import mapper + + class T(fixtures.Base): + pass + mapper(T, t) + + sess = create_session() + t1 = T(value=True, name="t1") + t2 = T(value=False, name="t2") + t3 = T(value=True, name="t3") + sess.save(t1) + sess.save(t2) + sess.save(t3) + + sess.flush() + + for clear in (False, True): + if clear: + sess.clear() + self.assertEquals(sess.query(T).all(), [T(value=True, name="t1"), T(value=False, name="t2"), T(value=True, name="t3")]) + if clear: + sess.clear() + self.assertEquals(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"),T(value=True, name="t3")]) + if clear: + sess.clear() + self.assertEquals(sess.query(T).filter(T.value==False).all(), [T(value=False, name="t2")]) + + t2 = sess.query(T).get(t2.id) + t2.value = True + sess.flush() + self.assertEquals(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"), T(value=True, name="t2"), T(value=True, name="t3")]) + t2.value = False + sess.flush() + self.assertEquals(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"),T(value=True, name="t3")]) + + +class RowSwitchTest(ORMTest): + def define_tables(self, metadata): + global t1, t2, 
t3, t1t3 + + global T1, T2, T3 + + Session.remove() + + # parent + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30), nullable=False)) + + # onetomany + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30), nullable=False), + Column('t1id', Integer, ForeignKey('t1.id'),nullable=False), + ) + + # associated + t3 = Table('t3', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30), nullable=False), + ) + #manytomany + t1t3 = Table('t1t3', metadata, + Column('t1id', Integer, ForeignKey('t1.id'),nullable=False), + Column('t3id', Integer, ForeignKey('t3.id'),nullable=False), + ) + + class T1(fixtures.Base): + pass + + class T2(fixtures.Base): + pass + + class T3(fixtures.Base): + pass + + def tearDown(self): + Session.remove() + super(RowSwitchTest, self).tearDown() + + def test_onetomany(self): + mapper(T1, t1, properties={ + 't2s':relation(T2, cascade="all, delete-orphan") + }) + mapper(T2, t2) + + sess = Session(autoflush=False) + + o1 = T1(data='some t1', id=1) + o1.t2s.append(T2(data='some t2', id=1)) + o1.t2s.append(T2(data='some other t2', id=2)) + + sess.save(o1) + sess.flush() + + assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some t1')] + assert list(sess.execute(t2.select(), mapper=T1)) == [(1, 'some t2', 1), (2, 'some other t2', 1)] + + o2 = T1(data='some other t1', id=o1.id, t2s=[ + T2(data='third t2', id=3), + T2(data='fourth t2', id=4), + ]) + sess.delete(o1) + sess.save(o2) + sess.flush() + + assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some other t1')] + assert list(sess.execute(t2.select(), mapper=T1)) == [(3, 'third t2', 1), (4, 'fourth t2', 1)] + + def test_manytomany(self): + mapper(T1, t1, properties={ + 't3s':relation(T3, secondary=t1t3, cascade="all, delete-orphan") + }) + mapper(T3, t3) + + sess = Session(autoflush=False) + + o1 = T1(data='some t1', id=1) + o1.t3s.append(T3(data='some t3', id=1)) + o1.t3s.append(T3(data='some other t3', id=2)) + + sess.save(o1) + sess.flush() + + assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some t1')] + assert rowset(sess.execute(t1t3.select(), mapper=T1)) == set([(1,1), (1, 2)]) + assert list(sess.execute(t3.select(), mapper=T1)) == [(1, 'some t3'), (2, 'some other t3')] + + o2 = T1(data='some other t1', id=1, t3s=[ + T3(data='third t3', id=3), + T3(data='fourth t3', id=4), + ]) + sess.delete(o1) + sess.save(o2) + sess.flush() + + assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some other t1')] + assert list(sess.execute(t3.select(), mapper=T1)) == [(3, 'third t3'), (4, 'fourth t3')] + + def test_manytoone(self): + + mapper(T2, t2, properties={ + 't1':relation(T1) + }) + mapper(T1, t1) + + sess = Session(autoflush=False) + + o1 = T2(data='some t2', id=1) + o1.t1 = T1(data='some t1', id=1) + + sess.save(o1) + sess.flush() + + assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some t1')] + assert list(sess.execute(t2.select(), mapper=T1)) == [(1, 'some t2', 1)] + + o2 = T2(data='some other t2', id=1, t1=T1(data='some other t1', id=2)) + sess.delete(o1) + sess.delete(o1.t1) + sess.save(o2) + sess.flush() + + assert list(sess.execute(t1.select(), mapper=T1)) == [(2, 'some other t1')] + assert list(sess.execute(t2.select(), mapper=T1)) == [(1, 'some other t2', 2)] + +class TransactionTest(ORMTest): + __unsupported_on__ = ('mysql', 'mssql') + + # sqlite doesn't have deferrable constraints, but it allows them to + # be specified. 
it'll raise immediately post-INSERT, instead of at + # COMMIT. either way, this test should pass. + + def define_tables(self, metadata): + global t1, T1, t2, T2 + + Session.remove() + + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True)) + + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('t1_id', Integer, + ForeignKey('t1.id', deferrable=True, initially='deferred') + )) + + # deferred_constraint = \ + # DDL("ALTER TABLE t2 ADD CONSTRAINT t2_t1_id_fk FOREIGN KEY (t1_id) " + # "REFERENCES t1 (id) DEFERRABLE INITIALLY DEFERRED") + # deferred_constraint.execute_at('after-create', t2) + # t1.create() + # t2.create() + # t2.append_constraint(ForeignKeyConstraint(['t1_id'], ['t1.id'])) + + class T1(fixtures.Base): + pass + + class T2(fixtures.Base): + pass + + orm_mapper(T1, t1) + orm_mapper(T2, t2) + + def test_close_transaction_on_commit_fail(self): + Session = sessionmaker(autoflush=False, transactional=False) + sess = Session() + + # with a deferred constraint, this fails at COMMIT time instead + # of at INSERT time. + sess.save(T2(t1_id=123)) + + try: + sess.flush() + assert False + except: + # Flush needs to rollback also when commit fails + assert sess.transaction is None + + # todo: on 8.3 at least, the failed commit seems to close the cursor? + # needs investigation. leaving in the DDL above now to help verify + # that the new deferrable support on FK isn't involved in this issue. + if testing.against('postgres'): + t1.bind.engine.dispose() if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/perf/cascade_speed.py b/test/perf/cascade_speed.py index 34d046381f..dbf41a7f75 100644 --- a/test/perf/cascade_speed.py +++ b/test/perf/cascade_speed.py @@ -1,10 +1,10 @@ -import testbase +import testenv; testenv.simple_setup() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * from timeit import Timer import sys + meta = MetaData() orders = Table('orders', meta, @@ -62,7 +62,7 @@ class TimeTrial(object): for valueid in range(5): val = Value() val.attribute = attr - + def run(self, number): s = create_session() self.order = order = Order() diff --git a/test/perf/insertspeed.py b/test/perf/insertspeed.py new file mode 100644 index 0000000000..32877560eb --- /dev/null +++ b/test/perf/insertspeed.py @@ -0,0 +1,110 @@ +import testenv; testenv.simple_setup() +import sys, time +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import profiling + +db = create_engine('sqlite://') +metadata = MetaData(db) +Person_table = Table('Person', metadata, + Column('name', String(40)), + Column('sex', Integer), + Column('age', Integer)) + + +def sa_unprofiled_insertmany(n): + i = Person_table.insert() + i.execute([{'name':'John Doe','sex':1,'age':35} for j in xrange(n)]) + +def sqlite_unprofiled_insertmany(n): + conn = db.connect().connection + c = conn.cursor() + persons = [('john doe', 1, 35) for i in xrange(n)] + c.executemany("insert into Person(name, sex, age) values (?,?,?)", persons) + +@profiling.profiled('sa_profiled_insert_many', always=True) +def sa_profiled_insert_many(n): + i = Person_table.insert() + i.execute([{'name':'John Doe','sex':1,'age':35} for j in xrange(n)]) + s = Person_table.select() + r = s.execute() + res = [[value for value in row] for row in r.fetchall()] + +def sqlite_unprofiled_insert(n): + conn = db.connect().connection + c = conn.cursor() + for j in xrange(n): + c.execute("insert into Person(name, sex, age) values (?,?,?)", + ('john doe', 1, 35)) + +def 
sa_unprofiled_insert(n): + # Another option is to build Person_table.insert() outside of the + # loop. But it doesn't make much of a difference, so might as well + # use the worst-case/naive version here. + for j in xrange(n): + Person_table.insert().execute({'name':'John Doe','sex':1,'age':35}) + +@profiling.profiled('sa_profiled_insert', always=True) +def sa_profiled_insert(n): + i = Person_table.insert() + for j in xrange(n): + i.execute({'name':'John Doe','sex':1,'age':35}) + s = Person_table.select() + r = s.execute() + res = [[value for value in row] for row in r.fetchall()] + +def run_timed(fn, label, *args, **kw): + metadata.drop_all() + metadata.create_all() + + sys.stdout.write("%s (%s): " % (label, ', '.join([str(a) for a in args]))) + sys.stdout.flush() + + t = time.clock() + fn(*args, **kw) + t2 = time.clock() + + sys.stdout.write("%0.2f seconds\n" % (t2 - t)) + +def run_profiled(fn, label, *args, **kw): + metadata.drop_all() + metadata.create_all() + + print "%s (%s)" % (label, ', '.join([str(a) for a in args])) + fn(*args, **kw) + +def all(): + try: + print "Bulk INSERTS via executemany():\n" + + run_timed(sqlite_unprofiled_insertmany, + 'pysqlite bulk insert', + 50000) + + run_timed(sa_unprofiled_insertmany, + 'SQLAlchemy bulk insert', + 50000) + + run_profiled(sa_profiled_insert_many, + 'SQLAlchemy bulk insert/select, profiled', + 1000) + + print "\nIndividual INSERTS via execute():\n" + + run_timed(sqlite_unprofiled_insert, + "pysqlite individual insert", + 50000) + + run_timed(sa_unprofiled_insert, + "SQLAlchemy individual insert", + 50000) + + run_profiled(sa_profiled_insert, + 'SQLAlchemy individual insert/select, profiled', + 1000) + + finally: + metadata.drop_all() + +if __name__ == '__main__': + all() diff --git a/test/perf/masscreate.py b/test/perf/masscreate.py index 346a725e35..ae32f83e2c 100644 --- a/test/perf/masscreate.py +++ b/test/perf/masscreate.py @@ -1,7 +1,7 @@ # times how long it takes to create 26000 objects -import testbase +import testenv; testenv.simple_setup() -from sqlalchemy.orm.attributes import * +from sqlalchemy.orm import attributes import time import gc @@ -13,18 +13,17 @@ class User(object): class Address(object): pass -attr_manager = AttributeManager() if manage_attributes: - attr_manager.register_attribute(User, 'id', uselist=False) - attr_manager.register_attribute(User, 'name', uselist=False) - attr_manager.register_attribute(User, 'addresses', uselist=True, trackparent=True) - attr_manager.register_attribute(Address, 'email', uselist=False) + attributes.register_attribute(User, 'id', False, False) + attributes.register_attribute(User, 'name', False, False) + attributes.register_attribute(User, 'addresses', True, False, trackparent=True) + attributes.register_attribute(Address, 'email', False, False) now = time.time() for i in range(0,130): u = User() if init_attributes: - attr_manager.init_attr(u) + attributes.manage(u) u.id = i u.name = "user " + str(i) if not manage_attributes: @@ -32,7 +31,7 @@ for i in range(0,130): for j in range(0,200): a = Address() if init_attributes: - attr_manager.init_attr(a) + attributes.manage(a) a.email = 'foo@bar.com' u.addresses.append(a) # gc.collect() diff --git a/test/perf/masscreate2.py b/test/perf/masscreate2.py index 2e29a63272..25d4b49153 100644 --- a/test/perf/masscreate2.py +++ b/test/perf/masscreate2.py @@ -1,37 +1,36 @@ -import testbase +import testenv; testenv.simple_setup() import gc import random, string -from sqlalchemy.orm.attributes import * +from sqlalchemy.orm import attributes # with 
this test, run top. make sure the Python process doenst grow in size arbitrarily. class User(object): pass - + class Address(object): pass -attr_manager = AttributeManager() -attr_manager.register_attribute(User, 'id', uselist=False) -attr_manager.register_attribute(User, 'name', uselist=False) -attr_manager.register_attribute(User, 'addresses', uselist=True) -attr_manager.register_attribute(Address, 'email', uselist=False) -attr_manager.register_attribute(Address, 'user', uselist=False) - +attributes.register_attribute(User, 'id', False, False) +attributes.register_attribute(User, 'name', False, False) +attributes.register_attribute(User, 'addresses', True, False) +attributes.register_attribute(Address, 'email', False, False) +attributes.register_attribute(Address, 'user', False, False) + for i in xrange(1000): for j in xrange(1000): u = User() + attributes.manage(u) u.name = str(random.randint(0, 100000000)) for k in xrange(10): a = Address() a.email_address = str(random.randint(0, 100000000)) + attributes.manage(a) u.addresses.append(a) a.user = u print "clearing" #managed_attributes.clear() gc.collect() - - diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py index f1c0f292b0..bc2834ff74 100644 --- a/test/perf/masseagerload.py +++ b/test/perf/masseagerload.py @@ -1,5 +1,4 @@ -import testbase -import hotshot, hotshot.stats +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * @@ -7,11 +6,11 @@ from testlib import * NUM = 500 DIVISOR = 50 -meta = MetaData(testbase.db) -items = Table('items', meta, +meta = MetaData(testing.db) +items = Table('items', meta, Column('item_id', Integer, primary_key=True), Column('value', String(100))) -subitems = Table('subitems', meta, +subitems = Table('subitems', meta, Column('sub_id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('items.item_id')), Column('value', String(100))) @@ -34,12 +33,12 @@ def load(): z = ((x-1) * DIVISOR) + y l.append({'sub_id':z,'value':'this is item #%d' % z, 'parent_id':x}) #print l - subitems.insert().execute(*l) + subitems.insert().execute(*l) -@profiling.profiled('masseagerload', always=True) +@profiling.profiled('masseagerload', always=True, sort=['cumulative']) def masseagerload(session): query = session.query(Item) - l = query.select() + l = query.all() print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems" def all(): diff --git a/test/perf/massload.py b/test/perf/massload.py index 92cf0fe920..8343330838 100644 --- a/test/perf/massload.py +++ b/test/perf/massload.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() import time #import gc #import sqlalchemy.orm.attributes as attributes @@ -6,19 +6,22 @@ from sqlalchemy import * from sqlalchemy.orm import * from testlib import * -NUM = 2500 +""" +we are testing session.expunge() here, also that the attributes and unitofwork +packages dont keep dereferenced stuff hanging around. + +for best results, dont run with sqlite :memory: database, and keep an eye on +top while it runs """ -we are testing session.expunge() here, also that the attributes and unitofwork packages dont keep dereferenced -stuff hanging around. 
-for best results, dont run with sqlite :memory: database, and keep an eye on top while it runs""" +NUM = 2500 -class LoadTest(AssertMixin): +class LoadTest(TestBase, AssertsExecutionResults): def setUpAll(self): global items, meta - meta = MetaData(testbase.db) - items = Table('items', meta, + meta = MetaData(testing.db) + items = Table('items', meta, Column('item_id', Integer, primary_key=True), Column('value', String(100))) items.create() @@ -30,10 +33,10 @@ class LoadTest(AssertMixin): for y in range(x*500-500 + 1, x*500 + 1): l.append({'item_id':y, 'value':'this is item #%d' % y}) items.insert().execute(*l) - + def testload(self): class Item(object):pass - + m = mapper(Item, items) sess = create_session() now = time.time() @@ -41,7 +44,7 @@ class LoadTest(AssertMixin): for x in range (1,NUM/100): # this is not needed with cpython which clears non-circular refs immediately #gc.collect() - l = query.select(items.c.item_id.between(x*100 - 100 + 1, x*100)) + l = query.filter(items.c.item_id.between(x*100 - 100 + 1, x*100)).all() assert len(l) == 100 print "loaded ", len(l), " items " # modifying each object will insure that the objects get placed in the "dirty" list @@ -56,6 +59,7 @@ class LoadTest(AssertMixin): #objectstore.expunge(*l) total = time.time() -now print "total time ", total - + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/perf/massload2.py b/test/perf/massload2.py index d6424eb073..a3deb932f0 100644 --- a/test/perf/massload2.py +++ b/test/perf/massload2.py @@ -1,45 +1,38 @@ -import sys -sys.path.insert(0, './lib/') - -try: -# import sqlalchemy.mods.threadlocal - pass -except: - pass -from sqlalchemy import * -from testbase import Table, Column +import testenv; testenv.simple_setup() import time +from sqlalchemy import * +from sqlalchemy.orm import * -metadata = create_engine('sqlite://', echo=True) +metadata = MetaData(create_engine('sqlite://', echo=True)) -t1s = Table( 't1s', metadata, +t1s = Table( 't1s', metadata, Column( 'id', Integer, primary_key=True), Column('data', String(100)) - ) + ) -t2s = Table( 't2s', metadata, +t2s = Table( 't2s', metadata, Column( 'id', Integer, primary_key=True), Column( 't1id', Integer, ForeignKey("t1s.id"), nullable=True )) -t3s = Table( 't3s', metadata, +t3s = Table( 't3s', metadata, Column( 'id', Integer, primary_key=True), Column( 't2id', Integer, ForeignKey("t2s.id"), nullable=True )) -t4s = Table( 't4s', metadata, - Column( 'id', Integer, primary_key=True), +t4s = Table( 't4s', metadata, + Column( 'id', Integer, primary_key=True), Column( 't3id', Integer, ForeignKey("t3s.id"), nullable=True )) - + [t.create() for t in [t1s,t2s,t3s,t4s]] class T1( object ): pass class T2( object ): pass class T3( object ): pass -class T4( object ): pass +class T4( object ): pass mapper( T1, t1s ) -mapper( T2, t2s ) -mapper( T3, t3s ) -mapper( T4, t4s ) +mapper( T2, t2s ) +mapper( T3, t3s ) +mapper( T4, t4s ) cascade = "all, delete-orphan" use_backref = True @@ -55,22 +48,22 @@ else: now = time.time() print "start" -sess = create_session() +sess = create_session() o1 = T1() -sess.save(o1) +sess.save(o1) for i2 in range(10): o2 = T2() o1.t2s.append( o2 ) - + for i3 in range( 10 ): o3 = T3() o2.t3s.append( o3 ) - + for i4 in range( 10 ): o3.t4s.append ( T4() ) print i2, i3, i4 -print len([s for s in sess]) +print len([s for s in sess]) print "flushing" sess.flush() total = time.time() - now diff --git a/test/perf/masssave.py b/test/perf/masssave.py index dd03f39629..bf65c8fdf7 100644 --- a/test/perf/masssave.py +++ 
b/test/perf/masssave.py @@ -1,29 +1,29 @@ -import testbase +import testenv; testenv.configure_for_tests() import types from sqlalchemy import * from sqlalchemy.orm import * from testlib import * -NUM = 250000 +NUM = 2500 -class SaveTest(AssertMixin): +class SaveTest(TestBase, AssertsExecutionResults): def setUpAll(self): global items, metadata - metadata = MetaData(testbase.db) - items = Table('items', metadata, + metadata = MetaData(testing.db) + items = Table('items', metadata, Column('item_id', Integer, primary_key=True), Column('value', String(100))) items.create() def tearDownAll(self): clear_mappers() metadata.drop_all() - + def testsave(self): class Item(object):pass - + m = mapper(Item, items) - + for x in range(0,NUM/50): sess = create_session() query = sess.query(Item) @@ -48,5 +48,7 @@ class SaveTest(AssertMixin): rep.sort(sorter) for x in rep[0:30]: print x + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/perf/objselectspeed.py b/test/perf/objselectspeed.py new file mode 100644 index 0000000000..896fd4c494 --- /dev/null +++ b/test/perf/objselectspeed.py @@ -0,0 +1,110 @@ +import testenv; testenv.simple_setup() +import time, gc, resource +from sqlalchemy import * +from sqlalchemy.orm import * + + +db = create_engine('sqlite://') +metadata = MetaData(db) +Person_table = Table('Person', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(40)), + Column('sex', Integer), + Column('age', Integer)) + +class RawPerson(object): pass +class Person(object): pass +mapper(Person, Person_table) +compile_mappers() + +def setup(): + metadata.create_all() + i = Person_table.insert() + data = [{'name':'John Doe','sex':1,'age':35}] * 100 + for j in xrange(500): + i.execute(data) + print "Inserted 50,000 rows" + +def sqlite_select(entity_cls): + conn = db.connect().connection + cr = conn.cursor() + cr.execute("SELECT id, name, sex, age FROM Person") + people = [] + for row in cr.fetchall(): + person = entity_cls() + person.id = row[0] + person.name = row[1] + person.sex = row[2] + person.age = row[3] + people.append(person) + cr.close() + conn.close() + +def sql_select(entity_cls): + people = [] + for row in Person_table.select().execute().fetchall(): + person = entity_cls() + person.id = row.id + person.name = row.name + person.sex = row.sex + person.age = row.age + people.append(person) + +def orm_select(): + session = create_session() + people = session.query(Person).all() + +def all(): + setup() + try: + t, t2 = 0, 0 + def usage(label): + now = resource.getrusage(resource.RUSAGE_SELF) + print "%s: %0.3fs real, %0.3fs user, %0.3fs sys" % ( + label, t2 - t, + now.ru_utime - usage.last.ru_utime, + now.ru_stime - usage.last.ru_stime) + usage.snap(now) + usage.snap = lambda stats=None: setattr( + usage, 'last', stats or resource.getrusage(resource.RUSAGE_SELF)) + + gc.collect() + usage.snap() + t = time.clock() + sqlite_select(RawPerson) + t2 = time.clock() + usage('sqlite select/native') + + gc.collect() + usage.snap() + t = time.clock() + sqlite_select(Person) + t2 = time.clock() + usage('sqlite select/instrumented') + + gc.collect() + usage.snap() + t = time.clock() + sql_select(RawPerson) + t2 = time.clock() + usage('sqlalchemy.sql select/native') + + gc.collect() + usage.snap() + t = time.clock() + sql_select(Person) + t2 = time.clock() + usage('sqlalchemy.sql select/instrumented') + + gc.collect() + usage.snap() + t = time.clock() + orm_select() + t2 = time.clock() + usage('sqlalchemy.orm fetch') + finally: + metadata.drop_all() + + 
+if __name__ == '__main__': + all() diff --git a/test/perf/objupdatespeed.py b/test/perf/objupdatespeed.py new file mode 100644 index 0000000000..a49eb47245 --- /dev/null +++ b/test/perf/objupdatespeed.py @@ -0,0 +1,93 @@ +import testenv; testenv.configure_for_tests() +import time, gc, resource +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import * + +NUM = 100 + +metadata = MetaData(testing.db) +Person_table = Table('Person', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(40)), + Column('sex', Integer), + Column('age', Integer)) + +Email_table = Table('Email', metadata, + Column('id', Integer, primary_key=True), + Column('person_id', Integer, ForeignKey('Person.id')), + Column('address', String(300))) + +class Person(object): + pass +class Email(object): + def __repr__(self): + return '' % (getattr(self, 'id', None), + getattr(self, 'address', None)) + +mapper(Person, Person_table, properties={ + 'emails': relation(Email, backref='owner', lazy=False) + }) +mapper(Email, Email_table) +compile_mappers() + +def setup(): + metadata.create_all() + i = Person_table.insert() + data = [{'name':'John Doe','sex':1,'age':35}] * NUM + i.execute(data) + + i = Email_table.insert() + for j in xrange(1, NUM + 1): + i.execute(address='foo@bar', person_id=j) + if j % 2: + i.execute(address='baz@quux', person_id=j) + + print "Inserted %d rows." % (NUM + NUM + (NUM // 2)) + +def orm_select(session): + return session.query(Person).all() + +@profiling.profiled('update_and_flush') +def update_and_flush(session, people): + for p in people: + p.name = 'Exene Cervenka' + p.sex = 2 + p.emails[0].address = 'hoho@lala' + session.flush() + +def all(): + setup() + try: + t, t2 = 0, 0 + def usage(label): + now = resource.getrusage(resource.RUSAGE_SELF) + print "%s: %0.3fs real, %0.3fs user, %0.3fs sys" % ( + label, t2 - t, + now.ru_utime - usage.last.ru_utime, + now.ru_stime - usage.last.ru_stime) + usage.snap(now) + usage.snap = lambda stats=None: setattr( + usage, 'last', stats or resource.getrusage(resource.RUSAGE_SELF)) + + session = create_session() + + gc.collect() + usage.snap() + t = time.clock() + people = orm_select(session) + t2 = time.clock() + usage('load objects') + + gc.collect() + usage.snap() + t = time.clock() + update_and_flush(session, people) + t2 = time.clock() + usage('update and flush') + finally: + metadata.drop_all() + + +if __name__ == '__main__': + all() diff --git a/test/perf/ormsession.py b/test/perf/ormsession.py index a9d310ef68..b0187a7871 100644 --- a/test/perf/ormsession.py +++ b/test/perf/ormsession.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() import time from datetime import datetime @@ -78,7 +78,7 @@ def insert_data(): q_sub_per_item = 10 q_customers = 1000 - con = testbase.db.connect() + con = testing.db.connect() transaction = con.begin() data, subdata = [], [] @@ -146,8 +146,8 @@ def insert_data(): def run_queries(): session = create_session() # no explicit transaction here. - - # build a report of summarizing the last 50 purchases and + + # build a report of summarizing the last 50 purchases and # the top 20 items from all purchases q = session.query(Purchase). 
\ @@ -165,7 +165,7 @@ def run_queries(): for item in purchase.items: report.append(item.name) report.extend([s.name for s in item.subitems]) - + # mix a little low-level with orm # pull a report of the top 20 items of all time _item_id = purchaseitems.c.item_id @@ -174,7 +174,7 @@ def run_queries(): order_by=[desc(func.count(_item_id)), _item_id], limit=20) ids = [r.id for r in top_20_q.execute().fetchall()] - q2 = session.query(Item).filter(Item.id.in_(*ids)) + q2 = session.query(Item).filter(Item.id.in_(ids)) for num, item in enumerate(q2): report.append("number %s: %s" % (num + 1, item.name)) @@ -189,7 +189,7 @@ def create_purchase(): session.begin() customer = session.query(Customer).get(customer_id) - items = session.query(Item).filter(Item.id.in_(*item_ids)) + items = session.query(Item).filter(Item.id.in_(item_ids)) purchase = Purchase() purchase.customer = customer @@ -212,7 +212,7 @@ def default(): @profiled('all') def main(): - metadata.bind = testbase.db + metadata.bind = testing.db try: define_tables() setup_mappers() diff --git a/test/perf/poolload.py b/test/perf/poolload.py index 1a2ff6978b..8d66da84f4 100644 --- a/test/perf/poolload.py +++ b/test/perf/poolload.py @@ -1,11 +1,11 @@ # load test of connection pool - -import testbase +import testenv; testenv.configure_for_tests() +import thread, time from sqlalchemy import * import sqlalchemy.pool as pool -import thread,time +from testlib import testing -db = create_engine(testbase.db.url, pool_timeout=30, echo_pool=True) +db = create_engine(testing.db.url, pool_timeout=30, echo_pool=True) metadata = MetaData(db) users_table = Table('users', metadata, @@ -30,8 +30,8 @@ def runfast(): # time.sleep(.005) # result.close() print "runfast cycle complete" - -#thread.start_new_thread(runslow, ()) + +#thread.start_new_thread(runslow, ()) for x in xrange(0,50): thread.start_new_thread(runfast, ()) diff --git a/test/perf/sessions.py b/test/perf/sessions.py new file mode 100644 index 0000000000..c5f7686531 --- /dev/null +++ b/test/perf/sessions.py @@ -0,0 +1,96 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import * +from testlib import fixtures +import gc + +# in this test we are specifically looking for time spent in the attributes.InstanceState.__cleanup() method. 
+ +ITERATIONS = 100 + +class SessionTest(TestBase, AssertsExecutionResults): + def setUpAll(self): + global t1, t2, metadata,T1, T2 + metadata = MetaData(testing.db) + t1 = Table('t1', metadata, + Column('c1', Integer, primary_key=True), + Column('c2', String(30))) + + t2 = Table('t2', metadata, + Column('c1', Integer, primary_key=True), + Column('c2', String(30)), + Column('t1id', Integer, ForeignKey('t1.c1')) + ) + + metadata.create_all() + + l = [] + for x in range(1,51): + l.append({'c2':'this is t1 #%d' % x}) + t1.insert().execute(*l) + for x in range(1, 51): + l = [] + for y in range(1, 100): + l.append({'c2':'this is t2 #%d' % y, 't1id':x}) + t2.insert().execute(*l) + + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't2s':relation(T2, backref='t1') + }) + mapper(T2, t2) + + def tearDownAll(self): + metadata.drop_all() + clear_mappers() + + @profiling.profiled('clean', report=True) + def test_session_clean(self): + for x in range(0, ITERATIONS): + sess = create_session() + t1s = sess.query(T1).filter(T1.c1.between(15, 48)).all() + for index in [2, 7, 12, 15, 18, 20]: + t1s[index].t2s + + sess.close() + del sess + gc.collect() + + @profiling.profiled('dirty', report=True) + def test_session_dirty(self): + for x in range(0, ITERATIONS): + sess = create_session() + t1s = sess.query(T1).filter(T1.c1.between(15, 48)).all() + + for index in [2, 7, 12, 15, 18, 20]: + t1s[index].c2 = 'this is some modified text' + for t2 in t1s[index].t2s: + t2.c2 = 'this is some modified text' + + del t1s + gc.collect() + + sess.close() + del sess + gc.collect() + + @profiling.profiled('noclose', report=True) + def test_session_noclose(self): + for x in range(0, ITERATIONS): + sess = create_session() + t1s = sess.query(T1).filter(T1.c1.between(15, 48)).all() + for index in [2, 7, 12, 15, 18, 20]: + t1s[index].t2s + + del sess + gc.collect() + + + +if __name__ == '__main__': + testenv.main() diff --git a/test/perf/threaded_compile.py b/test/perf/threaded_compile.py index 13ec31fd61..6809f2560d 100644 --- a/test/perf/threaded_compile.py +++ b/test/perf/threaded_compile.py @@ -1,21 +1,21 @@ """test that mapper compilation is threadsafe, including -when additional mappers are created while the existing +when additional mappers are created while the existing collection is being compiled.""" -import testbase +import testenv; testenv.simple_setup() from sqlalchemy import * from sqlalchemy.orm import * import thread, time from sqlalchemy.orm import mapperlib -from testlib import * + meta = MetaData('sqlite:///foo.db') -t1 = Table('t1', meta, +t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)) ) - + t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)), @@ -32,13 +32,13 @@ class T1(object): class T2(object): pass - + class FakeLock(object): def acquire(self):pass def release(self):pass # uncomment this to disable the mutex in mapper compilation; -# should produce thread collisions +# should produce thread collisions #mapperlib._COMPILE_MUTEX = FakeLock() def run1(): @@ -62,7 +62,7 @@ def run3(): class_mapper(Foo).compile() foo() time.sleep(.05) - + mapper(T1, t1, properties={'t2':relation(T2, backref="t1")}) mapper(T2, t2) print "START" @@ -74,4 +74,3 @@ for j in range(0, 5): thread.start_new_thread(run3, ()) print "WAIT" time.sleep(5) - diff --git a/test/perf/wsgi.py b/test/perf/wsgi.py index d22eeb76a0..6fc8149bcd 100644 --- a/test/perf/wsgi.py +++ b/test/perf/wsgi.py @@ -1,7 +1,7 @@ 
#!/usr/bin/python """Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop.""" -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * import thread @@ -14,8 +14,8 @@ logging.basicConfig() logging.getLogger('sqlalchemy.pool').setLevel(logging.INFO) threadids = set() -meta = MetaData(testbase.db) -foo = Table('foo', meta, +meta = MetaData(testing.db) +foo = Table('foo', meta, Column('id', Integer, primary_key=True), Column('data', String(30))) class Foo(object): @@ -41,7 +41,7 @@ def serve(environ, start_response): " total threads ", len(threadids)) return [str("\n".join([x.data for x in l]))] - + if __name__ == '__main__': from wsgiref import simple_server try: @@ -51,5 +51,3 @@ if __name__ == '__main__': server.serve_forever() finally: meta.drop_all() - - diff --git a/test/profiling/__init__.py b/test/profiling/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/profiling/alltests.py b/test/profiling/alltests.py new file mode 100644 index 0000000000..61e8a4a14c --- /dev/null +++ b/test/profiling/alltests.py @@ -0,0 +1,21 @@ +import testenv; testenv.configure_for_tests() +import unittest + + +def suite(): + modules_to_test = ( + 'profiling.compiler', + 'profiling.pool', + 'profiling.zoomark', + ) + alltests = unittest.TestSuite() + for name in modules_to_test: + mod = __import__(name) + for token in name.split('.')[1:]: + mod = getattr(mod, token) + alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) + return alltests + + +if __name__ == '__main__': + testenv.main(suite()) diff --git a/test/profiling/compiler.py b/test/profiling/compiler.py new file mode 100644 index 0000000000..4e1111aa2a --- /dev/null +++ b/test/profiling/compiler.py @@ -0,0 +1,33 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from testlib import * + + +class CompileTest(TestBase, AssertsExecutionResults): + def setUpAll(self): + global t1, t2, metadata + metadata = MetaData() + t1 = Table('t1', metadata, + Column('c1', Integer, primary_key=True), + Column('c2', String(30))) + + t2 = Table('t2', metadata, + Column('c1', Integer, primary_key=True), + Column('c2', String(30))) + + @profiling.function_call_count(74, {'2.3': 44, '2.4': 42}) + def test_insert(self): + t1.insert().compile() + + @profiling.function_call_count(75, {'2.3': 47, '2.4': 42}) + def test_update(self): + t1.update().compile() + + @profiling.function_call_count(228, versions={'2.3': 153, '2.4':116}) + def test_select(self): + s = select([t1], t1.c.c2==t2.c.c1) + s.compile() + + +if __name__ == '__main__': + testenv.main() diff --git a/test/profiling/pool.py b/test/profiling/pool.py new file mode 100644 index 0000000000..4b146fbabd --- /dev/null +++ b/test/profiling/pool.py @@ -0,0 +1,49 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from testlib import * +from sqlalchemy.pool import QueuePool + + +class QueuePoolTest(TestBase, AssertsExecutionResults): + class Connection(object): + def close(self): + pass + + def setUp(self): + global pool + pool = QueuePool(creator=self.Connection, + pool_size=3, max_overflow=-1, + use_threadlocal=True) + + # the WeakValueDictionary used for the pool's "threadlocal" idea adds 1-6 + # method calls to each of these. however its just a lot easier stability + # wise than dealing with a strongly referencing dict of weakrefs. 
+ # [ticket:754] immediately got opened when we tried a dict of weakrefs, + # and though the solution there is simple, it still doesn't solve the + # issue of "dead" weakrefs sitting in the dict taking up space + + @profiling.function_call_count(63, {'2.3': 42, '2.4': 43}) + def test_first_connect(self): + conn = pool.connect() + + def test_second_connect(self): + conn = pool.connect() + conn.close() + + @profiling.function_call_count(39, {'2.3': 26, '2.4': 26}) + def go(): + conn2 = pool.connect() + return conn2 + c2 = go() + + def test_second_samethread_connect(self): + conn = pool.connect() + + @profiling.function_call_count(7, {'2.3': 4, '2.4': 4}) + def go(): + return pool.connect() + c2 = go() + + +if __name__ == '__main__': + testenv.main() diff --git a/test/profiling/zoomark.py b/test/profiling/zoomark.py new file mode 100644 index 0000000000..0994b5d4be --- /dev/null +++ b/test/profiling/zoomark.py @@ -0,0 +1,360 @@ +"""Benchmark for SQLAlchemy. + +An adaptation of Robert Brewers' ZooMark speed tests. +""" + +import datetime +import sys +import time +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from testlib import * + +ITERATIONS = 1 + +dbapi_session = engines.ReplayableSession() +metadata = None + +class ZooMarkTest(TestBase): + """Runs the ZooMark and squawks if method counts vary from the norm. + + Each test has an associated `call_range`, the total number of accepted + function calls made during the test. The count can vary between Python + 2.4 and 2.5. + + Unlike a unit test, this is a ordered collection of steps. Running + components individually will fail. + + """ + + __only_on__ = 'postgres' + __skip_if__ = ((lambda: sys.version_info < (2, 4)), ) + + def test_baseline_0_setup(self): + global metadata + + creator = testing.db.pool._creator + recorder = lambda: dbapi_session.recorder(creator()) + engine = engines.testing_engine(options={'creator':recorder}) + metadata = MetaData(engine) + + def test_baseline_1_create_tables(self): + Zoo = Table('Zoo', metadata, + Column('ID', Integer, Sequence('zoo_id_seq'), + primary_key=True, index=True), + Column('Name', Unicode(255)), + Column('Founded', Date), + Column('Opens', Time), + Column('LastEscape', DateTime), + Column('Admission', Float), + ) + + Animal = Table('Animal', metadata, + Column('ID', Integer, Sequence('animal_id_seq'), + primary_key=True), + Column('ZooID', Integer, ForeignKey('Zoo.ID'), + index=True), + Column('Name', Unicode(100)), + Column('Species', Unicode(100)), + Column('Legs', Integer, default=4), + Column('LastEscape', DateTime), + Column('Lifespan', Float(4)), + Column('MotherID', Integer, ForeignKey('Animal.ID')), + Column('PreferredFoodID', Integer), + Column('AlternateFoodID', Integer), + ) + metadata.create_all() + + def test_baseline_1a_populate(self): + Zoo = metadata.tables['Zoo'] + Animal = metadata.tables['Animal'] + + wap = Zoo.insert().execute(Name=u'Wild Animal Park', + Founded=datetime.date(2000, 1, 1), + # 59 can give rounding errors with divmod, which + # AdapterFromADO needs to correct. 
+ Opens=datetime.time(8, 15, 59), + LastEscape=datetime.datetime(2004, 7, 29, 5, 6, 7), + Admission=4.95, + ).last_inserted_ids()[0] + + sdz = Zoo.insert().execute(Name =u'San Diego Zoo', + Founded = datetime.date(1935, 9, 13), + Opens = datetime.time(9, 0, 0), + Admission = 0, + ).last_inserted_ids()[0] + + Zoo.insert().execute( + Name = u'Montr\xe9al Biod\xf4me', + Founded = datetime.date(1992, 6, 19), + Opens = datetime.time(9, 0, 0), + Admission = 11.75, + ) + + seaworld = Zoo.insert().execute( + Name =u'Sea_World', Admission = 60).last_inserted_ids()[0] + + # Let's add a crazy futuristic Zoo to test large date values. + lp = Zoo.insert().execute(Name =u'Luna Park', + Founded = datetime.date(2072, 7, 17), + Opens = datetime.time(0, 0, 0), + Admission = 134.95, + ).last_inserted_ids()[0] + + # Animals + leopardid = Animal.insert().execute(Species=u'Leopard', Lifespan=73.5, + ).last_inserted_ids()[0] + Animal.update(Animal.c.ID==leopardid).execute(ZooID=wap, + LastEscape=datetime.datetime(2004, 12, 21, 8, 15, 0, 999907)) + + lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).last_inserted_ids()[0] + Animal.insert().execute(Species=u'Slug', Legs=1, Lifespan=.75) + + tiger = Animal.insert().execute(Species=u'Tiger', ZooID=sdz + ).last_inserted_ids()[0] + + # Override Legs.default with itself just to make sure it works. + Animal.insert().execute(Species=u'Bear', Legs=4) + Animal.insert().execute(Species=u'Ostrich', Legs=2, Lifespan=103.2) + Animal.insert().execute(Species=u'Centipede', Legs=100) + + emp = Animal.insert().execute(Species=u'Emperor Penguin', Legs=2, + ZooID=seaworld).last_inserted_ids()[0] + adelie = Animal.insert().execute(Species=u'Adelie Penguin', Legs=2, + ZooID=seaworld).last_inserted_ids()[0] + + Animal.insert().execute(Species=u'Millipede', Legs=1000000, ZooID=sdz) + + # Add a mother and child to test relationships + bai_yun = Animal.insert().execute(Species=u'Ape', Name=u'Bai Yun', + Legs=2).last_inserted_ids()[0] + Animal.insert().execute(Species=u'Ape', Name=u'Hua Mei', Legs=2, + MotherID=bai_yun) + + def test_baseline_2_insert(self): + Animal = metadata.tables['Animal'] + i = Animal.insert() + for x in xrange(ITERATIONS): + tick = i.execute(Species=u'Tick', Name=u'Tick %d' % x, Legs=8) + + def test_baseline_3_properties(self): + Zoo = metadata.tables['Zoo'] + Animal = metadata.tables['Animal'] + + def fullobject(select): + """Iterate over the full result row.""" + return list(select.execute().fetchone()) + + for x in xrange(ITERATIONS): + # Zoos + WAP = fullobject(Zoo.select(Zoo.c.Name==u'Wild Animal Park')) + SDZ = fullobject(Zoo.select(Zoo.c.Founded==datetime.date(1935, 9, 13))) + Biodome = fullobject(Zoo.select(Zoo.c.Name==u'Montr\xe9al Biod\xf4me')) + seaworld = fullobject(Zoo.select(Zoo.c.Admission == float(60))) + + # Animals + leopard = fullobject(Animal.select(Animal.c.Species ==u'Leopard')) + ostrich = fullobject(Animal.select(Animal.c.Species==u'Ostrich')) + millipede = fullobject(Animal.select(Animal.c.Legs==1000000)) + ticks = fullobject(Animal.select(Animal.c.Species==u'Tick')) + + def test_baseline_4_expressions(self): + Zoo = metadata.tables['Zoo'] + Animal = metadata.tables['Animal'] + + def fulltable(select): + """Iterate over the full result table.""" + return [list(row) for row in select.execute().fetchall()] + + for x in xrange(ITERATIONS): + assert len(fulltable(Zoo.select())) == 5 + assert len(fulltable(Animal.select())) == ITERATIONS + 12 + assert len(fulltable(Animal.select(Animal.c.Legs==4))) == 4 + assert 
len(fulltable(Animal.select(Animal.c.Legs == 2))) == 5 + assert len(fulltable(Animal.select(and_(Animal.c.Legs >= 2, Animal.c.Legs < 20) + ))) == ITERATIONS + 9 + assert len(fulltable(Animal.select(Animal.c.Legs > 10))) == 2 + assert len(fulltable(Animal.select(Animal.c.Lifespan > 70))) == 2 + assert len(fulltable(Animal.select(Animal.c.Species.startswith(u'L')))) == 2 + assert len(fulltable(Animal.select(Animal.c.Species.endswith(u'pede')))) == 2 + + assert len(fulltable(Animal.select(Animal.c.LastEscape != None))) == 1 + assert len(fulltable(Animal.select(None == Animal.c.LastEscape + ))) == ITERATIONS + 11 + + # In operator (containedby) + assert len(fulltable(Animal.select(Animal.c.Species.like(u'%pede%')))) == 2 + assert len(fulltable(Animal.select(Animal.c.Species.in_([u'Lion', u'Tiger', u'Bear'])))) == 3 + + # Try In with cell references + class thing(object): pass + pet, pet2 = thing(), thing() + pet.Name, pet2.Name =u'Slug', u'Ostrich' + assert len(fulltable(Animal.select(Animal.c.Species.in_([pet.Name, pet2.Name])))) == 2 + + # logic and other functions + assert len(fulltable(Animal.select(Animal.c.Species.like(u'Slug')))) == 1 + assert len(fulltable(Animal.select(Animal.c.Species.like(u'%pede%')))) == 2 + name =u'Lion' + assert len(fulltable(Animal.select(func.length(Animal.c.Species) == len(name) + ))) == ITERATIONS + 3 + + assert len(fulltable(Animal.select(Animal.c.Species.like(u'%i%') + ))) == ITERATIONS + 7 + + # Test now(), today(), year(), month(), day() + assert len(fulltable(Zoo.select(Zoo.c.Founded != None + and Zoo.c.Founded < func.current_timestamp(_type=Date)))) == 3 + assert len(fulltable(Animal.select(Animal.c.LastEscape == func.current_timestamp(_type=Date)))) == 0 + assert len(fulltable(Animal.select(func.date_part('year', Animal.c.LastEscape) == 2004))) == 1 + assert len(fulltable(Animal.select(func.date_part('month', Animal.c.LastEscape) == 12))) == 1 + assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1 + + def test_baseline_5_aggregates(self): + Animal = metadata.tables['Animal'] + Zoo = metadata.tables['Zoo'] + + for x in xrange(ITERATIONS): + # views + view = select([Animal.c.Legs]).execute().fetchall() + legs = [x[0] for x in view] + legs.sort() + + expected = {'Leopard': 73.5, + 'Slug': .75, + 'Tiger': None, + 'Lion': None, + 'Bear': None, + 'Ostrich': 103.2, + 'Centipede': None, + 'Emperor Penguin': None, + 'Adelie Penguin': None, + 'Millipede': None, + 'Ape': None, + 'Tick': None, + } + for species, lifespan in select([Animal.c.Species, Animal.c.Lifespan] + ).execute().fetchall(): + assert lifespan == expected[species] + + expected = [u'Montr\xe9al Biod\xf4me', 'Wild Animal Park'] + e = select([Zoo.c.Name], + and_(Zoo.c.Founded != None, + Zoo.c.Founded <= func.current_timestamp(), + Zoo.c.Founded >= datetime.date(1990, 1, 1))) + values = [val[0] for val in e.execute().fetchall()] + assert set(values) == set(expected) + + # distinct + legs = [x[0] for x in + select([Animal.c.Legs], distinct=True).execute().fetchall()] + legs.sort() + + def test_baseline_6_editing(self): + Zoo = metadata.tables['Zoo'] + + for x in xrange(ITERATIONS): + # Edit + SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone() + Zoo.update(Zoo.c.ID==SDZ['ID']).execute( + Name=u'The San Diego Zoo', + Founded = datetime.date(1900, 1, 1), + Opens = datetime.time(7, 30, 0), + Admission = "35.00") + + # Test edits + SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().fetchone() + assert SDZ['Founded'] == datetime.date(1900, 1, 
1), SDZ['Founded'] + + # Change it back + Zoo.update(Zoo.c.ID==SDZ['ID']).execute( + Name =u'San Diego Zoo', + Founded = datetime.date(1935, 9, 13), + Opens = datetime.time(9, 0, 0), + Admission = "0") + + # Test re-edits + SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone() + assert SDZ['Founded'] == datetime.date(1935, 9, 13) + + def test_baseline_7_multiview(self): + Zoo = metadata.tables['Zoo'] + Animal = metadata.tables['Animal'] + + def fulltable(select): + """Iterate over the full result table.""" + return [list(row) for row in select.execute().fetchall()] + + for x in xrange(ITERATIONS): + za = fulltable(select([Zoo.c.ID] + list(Animal.c), + Zoo.c.Name ==u'San Diego Zoo', + from_obj = [join(Zoo, Animal)])) + + SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo') + + e = fulltable(select([Zoo.c.ID, Animal.c.ID], + and_(Zoo.c.Name==u'San Diego Zoo', + Animal.c.Species==u'Leopard'), + from_obj = [join(Zoo, Animal)])) + + # Now try the same query with INNER, LEFT, and RIGHT JOINs. + e = fulltable(select([Zoo.c.Name, Animal.c.Species], + from_obj=[join(Zoo, Animal)])) + e = fulltable(select([Zoo.c.Name, Animal.c.Species], + from_obj=[outerjoin(Zoo, Animal)])) + e = fulltable(select([Zoo.c.Name, Animal.c.Species], + from_obj=[outerjoin(Animal, Zoo)])) + + def test_baseline_8_drop(self): + metadata.drop_all() + + # Now, run all of these tests again with the DB-API driver factored out: + # the ReplayableSession playback stands in for the database. + + # How awkward is this in a unittest framework? Very. + + def test_profile_0(self): + global metadata + + player = lambda: dbapi_session.player() + engine = create_engine('postgres:///', creator=player) + metadata = MetaData(engine) + + @profiling.function_call_count(3230, {'2.4': 1796}) + def test_profile_1_create_tables(self): + self.test_baseline_1_create_tables() + + @profiling.function_call_count(6064, {'2.4': 3635}) + def test_profile_1a_populate(self): + self.test_baseline_1a_populate() + + @profiling.function_call_count(339, {'2.4': 195}) + def test_profile_2_insert(self): + self.test_baseline_2_insert() + + @profiling.function_call_count(4923, {'2.4': 2557}) + def test_profile_3_properties(self): + self.test_baseline_3_properties() + + @profiling.function_call_count(18119, {'2.4': 10549}) + def test_profile_4_expressions(self): + self.test_baseline_4_expressions() + + @profiling.function_call_count(1617, {'2.4': 1032}) + def test_profile_5_aggregates(self): + self.test_baseline_5_aggregates() + + @profiling.function_call_count(1988, {'2.4': 1048}) + def test_profile_6_editing(self): + self.test_baseline_6_editing() + + @profiling.function_call_count(3614, {'2.4': 2198}) + def test_profile_7_multiview(self): + self.test_baseline_7_multiview() + + def test_profile_8_drop(self): + self.test_baseline_8_drop() + + +if __name__ == '__main__': + testenv.main() diff --git a/test/sql/alltests.py b/test/sql/alltests.py index a669a25f2d..173b046327 100644 --- a/test/sql/alltests.py +++ b/test/sql/alltests.py @@ -1,26 +1,28 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest def suite(): modules_to_test = ( 'sql.testtypes', + 'sql.columns', 'sql.constraints', 'sql.generative', - + # SQL syntax 'sql.select', 'sql.selectable', - 'sql.case_statement', + 'sql.case_statement', 'sql.labels', 'sql.unicode', - + # assorted round-trip tests + 'sql.functions', 'sql.query', 'sql.quote', 'sql.rowcount', - + # defaults, sequences (postgres/oracle) 'sql.defaults', ) @@ -33,4 +35,4 @@ def suite(): return alltests if 
__name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index 493545b228..6aecefd3c3 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -1,37 +1,40 @@ -import testbase +import testenv; testenv.configure_for_tests() import sys from sqlalchemy import * from testlib import * +from sqlalchemy import util, exceptions +from sqlalchemy.sql import table, column -class CaseTest(PersistTest): +class CaseTest(TestBase, AssertsCompiledSQL): def setUpAll(self): - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) global info_table info_table = Table('infos', metadata, - Column('pk', Integer, primary_key=True), - Column('info', String(30))) + Column('pk', Integer, primary_key=True), + Column('info', String(30))) info_table.create() info_table.insert().execute( - {'pk':1, 'info':'pk_1_data'}, - {'pk':2, 'info':'pk_2_data'}, - {'pk':3, 'info':'pk_3_data'}, - {'pk':4, 'info':'pk_4_data'}, - {'pk':5, 'info':'pk_5_data'}, - {'pk':6, 'info':'pk_6_data'}) + {'pk':1, 'info':'pk_1_data'}, + {'pk':2, 'info':'pk_2_data'}, + {'pk':3, 'info':'pk_3_data'}, + {'pk':4, 'info':'pk_4_data'}, + {'pk':5, 'info':'pk_5_data'}, + {'pk':6, 'info':'pk_6_data'}) def tearDownAll(self): info_table.drop() - + + @testing.fails_on('maxdb') def testcase(self): inner = select([case([ - [info_table.c.pk < 3, - literal('lessthan3', type_=String)], - [and_(info_table.c.pk >= 3, info_table.c.pk < 7), - literal('gt3', type_=String)]]).label('x'), - info_table.c.pk, info_table.c.info], + [info_table.c.pk < 3, + 'lessthan3'], + [and_(info_table.c.pk >= 3, info_table.c.pk < 7), + 'gt3']]).label('x'), + info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') inner_result = inner.execute().fetchall() @@ -66,12 +69,12 @@ class CaseTest(PersistTest): ] w_else = select([case([ - [info_table.c.pk < 3, - literal(3, type_=Integer)], - [and_(info_table.c.pk >= 3, info_table.c.pk < 6), - literal(6, type_=Integer)]], + [info_table.c.pk < 3, + 3], + [and_(info_table.c.pk >= 3, info_table.c.pk < 6), + 6]], else_ = 0).label('x'), - info_table.c.pk, info_table.c.info], + info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') else_result = w_else.execute().fetchall() @@ -85,5 +88,47 @@ class CaseTest(PersistTest): (0, 6, 'pk_6_data') ] + def test_literal_interpretation(self): + t = table('test', column('col1')) + + self.assertRaises(exceptions.ArgumentError, case, [("x", "y")]) + + self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END") + self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END") + + + @testing.fails_on('maxdb') + def testcase_with_dict(self): + query = select([case({ + info_table.c.pk < 3: 'lessthan3', + info_table.c.pk >= 3: 'gt3', + }, else_='other'), + info_table.c.pk, info_table.c.info + ], + from_obj=[info_table]) + assert query.execute().fetchall() == [ + ('lessthan3', 1, 'pk_1_data'), + ('lessthan3', 2, 'pk_2_data'), + ('gt3', 3, 'pk_3_data'), + ('gt3', 4, 'pk_4_data'), + ('gt3', 5, 'pk_5_data'), + ('gt3', 6, 'pk_6_data') + ] + + simple_query = select([case({ + 1: 'one', + 2: 'two', + }, value=info_table.c.pk, else_='other'), + info_table.c.pk + ], + whereclause=info_table.c.pk < 4, + from_obj=[info_table]) + + assert simple_query.execute().fetchall() == [ + ('one', 1), + ('two', 2), + ('other', 3), + ] + if __name__ == "__main__": - testbase.main() + 
testenv.main() diff --git a/test/sql/columns.py b/test/sql/columns.py new file mode 100644 index 0000000000..76bf9b389c --- /dev/null +++ b/test/sql/columns.py @@ -0,0 +1,60 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy import exceptions, sql +from testlib import * +from sqlalchemy import Table, Column # don't use testlib's wrappers + + +class ColumnDefinitionTest(TestBase): + """Test Column() construction.""" + + # flesh this out with explicit coverage... + + def columns(self): + return [ Column(), + Column('b'), + Column(Integer), + Column('d', Integer), + Column(name='e'), + Column(type_=Integer), + Column(Integer()), + Column('h', Integer()), + Column(type_=Integer()) ] + + def test_basic(self): + c = self.columns() + + for i, v in ((0, 'a'), (2, 'c'), (5, 'f'), (6, 'g'), (8, 'i')): + c[i].name = v + c[i].key = v + del i, v + + tbl = Table('table', MetaData(), *c) + + for i, col in enumerate(tbl.c): + assert col.name == c[i].name + + def test_incomplete(self): + c = self.columns() + + self.assertRaises(exceptions.ArgumentError, Table, 't', MetaData(), *c) + + def test_incomplete_key(self): + c = Column(Integer) + assert c.name is None + assert c.key is None + + c.name = 'named' + t = Table('t', MetaData(), c) + + assert c.name == 'named' + assert c.name == c.key + + + def test_bogus(self): + self.assertRaises(exceptions.ArgumentError, Column, 'foo', name='bar') + self.assertRaises(exceptions.ArgumentError, Column, 'foo', Integer, + type_=Integer()) + +if __name__ == "__main__": + testenv.main() diff --git a/test/sql/constraints.py b/test/sql/constraints.py index 3120185d59..2908e07da9 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -1,18 +1,20 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import exceptions from testlib import * +from testlib import config, engines + +class ConstraintTest(TestBase, AssertsExecutionResults): -class ConstraintTest(AssertMixin): - def setUp(self): global metadata - metadata = MetaData(testbase.db) - + metadata = MetaData(testing.db) + def tearDown(self): metadata.drop_all() - + def test_constraint(self): - employees = Table('employees', metadata, + employees = Table('employees', metadata, Column('id', Integer), Column('soc', String(40)), Column('name', String(30)), @@ -29,7 +31,7 @@ class ConstraintTest(AssertMixin): metadata.create_all() def test_circular_constraint(self): - a = Table("a", metadata, + a = Table("a", metadata, Column('id', Integer, primary_key=True), Column('bid', Integer), ForeignKeyConstraint(["bid"], ["b.id"], name="afk") @@ -42,7 +44,7 @@ class ConstraintTest(AssertMixin): metadata.create_all() def test_circular_constraint_2(self): - a = Table("a", metadata, + a = Table("a", metadata, Column('id', Integer, primary_key=True), Column('bid', Integer, ForeignKey("b.id")), ) @@ -51,15 +53,15 @@ class ConstraintTest(AssertMixin): Column("aid", Integer, ForeignKey("a.id", use_alter=True, name="bfk")), ) metadata.create_all() - + @testing.unsupported('mysql') def test_check_constraint(self): - foo = Table('foo', metadata, + foo = Table('foo', metadata, Column('id', Integer, primary_key=True), Column('x', Integer), Column('y', Integer), CheckConstraint('x>y')) - bar = Table('bar', metadata, + bar = Table('bar', metadata, Column('id', Integer, primary_key=True), Column('x', Integer, CheckConstraint('x>7')), Column('z', Integer) @@ -79,7 +81,7 @@ class ConstraintTest(AssertMixin): assert False except 
exceptions.SQLError: assert True - + def test_unique_constraint(self): foo = Table('foo', metadata, Column('id', Integer, primary_key=True), @@ -105,7 +107,7 @@ class ConstraintTest(AssertMixin): assert False except exceptions.SQLError: assert True - + def test_index_create(self): employees = Table('employees', metadata, Column('id', Integer, primary_key=True), @@ -113,14 +115,14 @@ class ConstraintTest(AssertMixin): Column('last_name', String(30)), Column('email_address', String(30))) employees.create() - + i = Index('employee_name_index', employees.c.last_name, employees.c.first_name) i.create() assert i in employees.indexes - + i2 = Index('employee_email_index', - employees.c.email_address, unique=True) + employees.c.email_address, unique=True) i2.create() assert i2 in employees.indexes @@ -133,13 +135,13 @@ class ConstraintTest(AssertMixin): Column('emailAddress', String(30))) employees.create() - + i = Index('employeeNameIndex', employees.c.lastName, employees.c.firstName) i.create() - + i = Index('employeeEmailIndex', - employees.c.emailAddress, unique=True) + employees.c.emailAddress, unique=True) i.create() # Check that the table is useable. This is mostly for pg, @@ -162,7 +164,7 @@ class ConstraintTest(AssertMixin): Index('sport_announcer', events.c.sport, events.c.announcer, unique=True) Index('idx_winners', events.c.winner) - + index_names = [ ix.name for ix in events.indexes ] assert 'ix_events_name' in index_names assert 'ix_events_location' in index_names @@ -171,34 +173,138 @@ class ConstraintTest(AssertMixin): assert len(index_names) == 4 capt = [] - connection = testbase.db.connect() + connection = testing.db.connect() # TODO: hacky, put a real connection proxy in - ex = connection._Connection__execute + ex = connection._Connection__execute_raw def proxy(context): capt.append(context.statement) capt.append(repr(context.parameters)) ex(context) - connection._Connection__execute = proxy - schemagen = testbase.db.dialect.schemagenerator(connection) + connection._Connection__execute_raw = proxy + schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection) schemagen.traverse(events) - + assert capt[0].strip().startswith('CREATE TABLE events') - + s = set([capt[x].strip() for x in [2,4,6,8]]) - + assert s == set([ 'CREATE UNIQUE INDEX ix_events_name ON events (name)', 'CREATE INDEX ix_events_location ON events (location)', 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)', 'CREATE INDEX idx_winners ON events (winner)' ]) - + # verify that the table is functional events.insert().execute(id=1, name='hockey finals', location='rink', sport='hockey', announcer='some canadian', winner='sweden') ss = events.select().execute().fetchall() - -if __name__ == "__main__": - testbase.main() + +class ConstraintCompilationTest(TestBase, AssertsExecutionResults): + class accum(object): + def __init__(self): + self.statements = [] + def __call__(self, sql, *a, **kw): + self.statements.append(sql) + def __contains__(self, substring): + for s in self.statements: + if substring in s: + return True + return False + def __str__(self): + return '\n'.join([repr(x) for x in self.statements]) + def clear(self): + del self.statements[:] + + def setUp(self): + self.sql = self.accum() + opts = config.db_opts.copy() + opts['strategy'] = 'mock' + opts['executor'] = self.sql + self.engine = engines.testing_engine(options=opts) + + + def _test_deferrable(self, constraint_factory): + meta = MetaData(self.engine) + t = Table('tbl', meta, + Column('a', Integer), + Column('b', 
Integer), + constraint_factory(deferrable=True)) + t.create() + assert 'DEFERRABLE' in self.sql, self.sql + assert 'NOT DEFERRABLE' not in self.sql, self.sql + self.sql.clear() + meta.clear() + + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer), + constraint_factory(deferrable=False)) + t.create() + assert 'NOT DEFERRABLE' in self.sql + self.sql.clear() + meta.clear() + + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer), + constraint_factory(deferrable=True, initially='IMMEDIATE')) + t.create() + assert 'NOT DEFERRABLE' not in self.sql + assert 'INITIALLY IMMEDIATE' in self.sql + self.sql.clear() + meta.clear() + + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer), + constraint_factory(deferrable=True, initially='DEFERRED')) + t.create() + + assert 'NOT DEFERRABLE' not in self.sql + assert 'INITIALLY DEFERRED' in self.sql, self.sql + + def test_deferrable_pk(self): + factory = lambda **kw: PrimaryKeyConstraint('a', **kw) + self._test_deferrable(factory) + + def test_deferrable_table_fk(self): + factory = lambda **kw: ForeignKeyConstraint(['b'], ['tbl.a'], **kw) + self._test_deferrable(factory) + + def test_deferrable_column_fk(self): + meta = MetaData(self.engine) + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer, + ForeignKey('tbl.a', deferrable=True, + initially='DEFERRED'))) + t.create() + assert 'DEFERRABLE' in self.sql, self.sql + assert 'INITIALLY DEFERRED' in self.sql, self.sql + + def test_deferrable_unique(self): + factory = lambda **kw: UniqueConstraint('b', **kw) + self._test_deferrable(factory) + + def test_deferrable_table_check(self): + factory = lambda **kw: CheckConstraint('a < b', **kw) + self._test_deferrable(factory) + + def test_deferrable_column_check(self): + meta = MetaData(self.engine) + t = Table('tbl', meta, + Column('a', Integer), + Column('b', Integer, + CheckConstraint('a < b', + deferrable=True, + initially='DEFERRED'))) + t.create() + assert 'DEFERRABLE' in self.sql, self.sql + assert 'INITIALLY DEFERRED' in self.sql, self.sql + + +if __name__ == "__main__": + testenv.main() diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 6c200232f2..22660c0607 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -1,83 +1,97 @@ -import testbase +import testenv; testenv.configure_for_tests() +import datetime from sqlalchemy import * -import sqlalchemy.util as util -import sqlalchemy.schema as schema +from sqlalchemy import exceptions, schema, util from sqlalchemy.orm import mapper, create_session from testlib import * -import datetime -class DefaultTest(PersistTest): + +class DefaultTest(TestBase): def setUpAll(self): - global t, f, f2, ts, currenttime, metadata + global t, f, f2, ts, currenttime, metadata, default_generator - db = testbase.db + db = testing.db metadata = MetaData(db) - x = {'x':50} - def mydefault(): - x['x'] += 1 - return x['x'] + default_generator = {'x':50} - def mydefault_with_ctx(ctx): - return ctx.compiled_parameters['col1'] + 10 + def mydefault(): + default_generator['x'] += 1 + return default_generator['x'] def myupdate_with_ctx(ctx): - return len(ctx.compiled_parameters['col2']) - - use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle' - is_oracle = db.engine.name == 'oracle' - + conn = ctx.connection + return conn.execute(select([text('13')])).scalar() + + def mydefault_using_connection(ctx): + conn = ctx.connection + try: + return conn.execute(select([text('12')])).scalar() + finally: + # ensure a "close()" on 
this connection does nothing, + # since its a "branched" connection + conn.close() + + use_function_defaults = testing.against('postgres', 'oracle') + is_oracle = testing.against('oracle') + # select "count(1)" returns different results on different DBs # also correct for "current_date" compatible as column default, value differences - currenttime = func.current_date(type_=Date, bind=db); + currenttime = func.current_date(type_=Date, bind=db) + if is_oracle: - ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar() - f = select([func.count(1) + 5], bind=db).scalar() - f2 = select([func.count(1) + 14], bind=db).scalar() + ts = db.scalar(select([func.trunc(func.sysdate(), literal_column("'DAY'"), type_=Date).label('today')])) + assert isinstance(ts, datetime.date) and not isinstance(ts, datetime.datetime) + f = select([func.length('abcdef')], bind=db).scalar() + f2 = select([func.length('abcdefghijk')], bind=db).scalar() # TODO: engine propigation across nested functions not working - currenttime = func.trunc(currenttime, literal_column("'DAY'"), bind=db) + currenttime = func.trunc(currenttime, literal_column("'DAY'"), bind=db, type_=Date) def1 = currenttime - def2 = func.trunc(text("sysdate"), literal_column("'DAY'")) + def2 = func.trunc(text("sysdate"), literal_column("'DAY'"), type_=Date) + deftype = Date elif use_function_defaults: - f = select([func.count(1) + 5], bind=db).scalar() - f2 = select([func.count(1) + 14], bind=db).scalar() + f = select([func.length('abcdef')], bind=db).scalar() + f2 = select([func.length('abcdefghijk')], bind=db).scalar() def1 = currenttime - def2 = text("current_date") + if testing.against('maxdb'): + def2 = text("curdate") + else: + def2 = text("current_date") deftype = Date ts = db.func.current_date().scalar() else: - f = select([func.count(1) + 5], bind=db).scalar() - f2 = select([func.count(1) + 14], bind=db).scalar() + f = select([func.length('abcdef')], bind=db).scalar() + f2 = select([func.length('abcdefghijk')], bind=db).scalar() def1 = def2 = "3" ts = 3 deftype = Integer - + t = Table('default_test1', metadata, # python function Column('col1', Integer, primary_key=True, default=mydefault), - + # python literal Column('col2', String(20), default="imthedefault", onupdate="im the update"), - + # preexecute expression - Column('col3', Integer, default=func.count(1) + 5, onupdate=func.count(1) + 14), - + Column('col3', Integer, default=func.length('abcdef'), onupdate=func.length('abcdefghijk')), + # SQL-side default from sql expression Column('col4', deftype, PassiveDefault(def1)), - + # SQL-side default from literal expression Column('col5', deftype, PassiveDefault(def2)), - + # preexecute + update timestamp Column('col6', Date, default=currenttime, onupdate=currenttime), - + Column('boolcol1', Boolean, default=True), Column('boolcol2', Boolean, default=False), - + # python function which uses ExecutionContext - Column('col7', Integer, default=mydefault_with_ctx, onupdate=myupdate_with_ctx), - + Column('col7', Integer, default=mydefault_using_connection, onupdate=myupdate_with_ctx), + # python builtin Column('col8', Date, default=datetime.date.today, onupdate=datetime.date.today) ) @@ -85,59 +99,138 @@ class DefaultTest(PersistTest): def tearDownAll(self): t.drop() - + def tearDown(self): + default_generator['x'] = 50 t.delete().execute() - - def testargsignature(self): - def mydefault(x, y): - pass - try: - c = ColumnDefault(mydefault) - assert False - except exceptions.ArgumentError, e: - assert str(e) == "ColumnDefault Python function 
takes zero or one positional arguments", str(e) - + + def test_bad_argsignature(self): + ex_msg = \ + "ColumnDefault Python function takes zero or one positional arguments" + + def fn1(x, y): pass + def fn2(x, y, z=3): pass + class fn3(object): + def __init__(self, x, y): + pass + class FN4(object): + def __call__(self, x, y): + pass + fn4 = FN4() + + for fn in fn1, fn2, fn3, fn4: + try: + c = ColumnDefault(fn) + assert False, str(fn) + except exceptions.ArgumentError, e: + assert str(e) == ex_msg + + def test_argsignature(self): + def fn1(): pass + def fn2(): pass + def fn3(x=1): pass + def fn4(x=1, y=2, z=3): pass + fn5 = list + class fn6(object): + def __init__(self, x): + pass + class fn6(object): + def __init__(self, x, y=3): + pass + class FN7(object): + def __call__(self, x): + pass + fn7 = FN7() + class FN8(object): + def __call__(self, x, y=3): + pass + fn8 = FN8() + + for fn in fn1, fn2, fn3, fn4, fn5, fn6, fn7, fn8: + c = ColumnDefault(fn) + def teststandalone(self): - c = testbase.db.engine.contextual_connect() + c = testing.db.engine.contextual_connect() x = c.execute(t.c.col1.default) y = t.c.col2.default.execute() z = c.execute(t.c.col3.default) self.assert_(50 <= x <= 57) self.assert_(y == 'imthedefault') self.assert_(z == f) - # mysql/other db's return 0 or 1 for count(1) - self.assert_(5 <= z <= 6) - + self.assert_(f2==11) + def testinsert(self): r = t.insert().execute() - self.assert_(r.lastrow_has_defaults()) + assert r.lastrow_has_defaults() + assert util.Set(r.context.postfetch_cols) == util.Set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]) + + r = t.insert(inline=True).execute() + assert r.lastrow_has_defaults() + assert util.Set(r.context.postfetch_cols) == util.Set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]) + t.insert().execute() t.insert().execute() + ctexec = select([currenttime.label('now')], bind=testing.db).scalar() + l = t.select().execute() + today = datetime.date.today() + self.assertEquals(l.fetchall(), [ + (51, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), + (52, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), + (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), + (54, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), + ]) + + def testinsertmany(self): + # MySQL-Python 1.2.2 breaks functions in execute_many :( + if (testing.against('mysql') and + testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): + return + + r = t.insert().execute({}, {}, {}) + ctexec = currenttime.scalar() - print "Currenttime "+ repr(ctexec) l = t.select().execute() today = datetime.date.today() - self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 61, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 62, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 63, today)]) + self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today)]) def testinsertvalues(self): t.insert(values={'col3':50}).execute() l = t.select().execute() self.assert_(l.fetchone()['col3'] == 50) - - + + def testupdatemany(self): + # MySQL-Python 1.2.2 breaks functions in execute_many :( + if (testing.against('mysql') and + testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): + return + + t.insert().execute({}, {}, {}) + + t.update(t.c.col1==bindparam('pkval')).execute( + {'pkval':51,'col7':None, 'col8':None, 'boolcol1':False}, + ) + + 
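        # Descriptive note on the block below: the executemany-style update
        # (multiple bind dicts against one statement) evaluates every onupdate
        # default again for each row -- col2's literal "im the update", col3's
        # func.length(), col6's currenttime, col7's myupdate_with_ctx (13) and
        # col8's date.today() -- which is what the final fetchall() assertion
        # checks; boolcol1, explicitly set to False for pk 51 above and
        # carrying no onupdate, keeps that value while rows 52 and 53 retain
        # their insert-time True.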
t.update(t.c.col1==bindparam('pkval')).execute( + {'pkval':51,}, + {'pkval':52,}, + {'pkval':53,}, + ) + + l = t.select().execute() + ctexec = currenttime.scalar() + today = datetime.date.today() + self.assert_(l.fetchall() == [(51, 'im the update', f2, ts, ts, ctexec, False, False, 13, today), (52, 'im the update', f2, ts, ts, ctexec, True, False, 13, today), (53, 'im the update', f2, ts, ts, ctexec, True, False, 13, today)]) + def testupdate(self): r = t.insert().execute() pk = r.last_inserted_ids()[0] t.update(t.c.col1==pk).execute(col4=None, col5=None) ctexec = currenttime.scalar() - print "Currenttime "+ repr(ctexec) l = t.select(t.c.col1==pk).execute() l = l.fetchone() self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today())) - # mysql/other db's return 0 or 1 for count(1) - self.assert_(14 <= f2 <= 15) + self.assert_(f2==11) def testupdatevalues(self): r = t.insert().execute() @@ -147,17 +240,17 @@ class DefaultTest(PersistTest): l = l.fetchone() self.assert_(l['col3'] == 55) - @testing.supported('postgres') + @testing.fails_on_everything_except('postgres') def testpassiveoverride(self): - """primarily for postgres, tests that when we get a primary key column back + """primarily for postgres, tests that when we get a primary key column back from reflecting a table which has a default value on it, we pre-execute - that PassiveDefault upon insert, even though PassiveDefault says + that PassiveDefault upon insert, even though PassiveDefault says "let the database execute this", because in postgres we must have all the primary key values in memory before insert; otherwise we cant locate the just inserted row.""" try: - meta = MetaData(testbase.db) - testbase.db.execute(""" + meta = MetaData(testing.db) + testing.db.execute(""" CREATE TABLE speedy_users ( speedy_user_id SERIAL PRIMARY KEY, @@ -172,112 +265,208 @@ class DefaultTest(PersistTest): l = t.select().execute().fetchall() self.assert_(l == [(1, 'user', 'lala')]) finally: - testbase.db.execute("drop table speedy_users", None) + testing.db.execute("drop table speedy_users", None) -class AutoIncrementTest(PersistTest): - @testing.supported('postgres', 'mysql') +class PKDefaultTest(TestBase): + def setUpAll(self): + global metadata, t1, t2 + + metadata = MetaData(testing.db) + + t2 = Table('t2', metadata, + Column('nextid', Integer)) + + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True, default=select([func.max(t2.c.nextid)]).as_scalar()), + Column('data', String(30))) + + metadata.create_all() + + def tearDownAll(self): + metadata.drop_all() + + @testing.unsupported('mssql') + def test_basic(self): + t2.insert().execute(nextid=1) + r = t1.insert().execute(data='hi') + assert r.last_inserted_ids() == [1] + + t2.insert().execute(nextid=2) + r = t1.insert().execute(data='there') + assert r.last_inserted_ids() == [2] + + +class AutoIncrementTest(TestBase): + def setUp(self): + global aitable, aimeta + + aimeta = MetaData(testing.db) + aitable = Table("aitest", aimeta, + Column('id', Integer, Sequence('ai_id_seq', optional=True), + primary_key=True), + Column('int1', Integer), + Column('str1', String(20))) + aimeta.create_all() + + def tearDown(self): + aimeta.drop_all() + + # should fail everywhere... was: @supported('postgres', 'mysql', 'maxdb') + @testing.fails_on('sqlite') def testnonautoincrement(self): - meta = MetaData(testbase.db) - nonai_table = Table("aitest", meta, + # sqlite INT primary keys can be non-unique! 
(only for ints) + meta = MetaData(testing.db) + nonai_table = Table("nonaitest", meta, Column('id', Integer, autoincrement=False, primary_key=True), Column('data', String(20))) - nonai_table.create() + nonai_table.create(checkfirst=True) try: try: - # postgres will fail on first row, mysql fails on second row + # postgres + mysql strict will fail on first row, + # mysql in legacy mode fails on second row nonai_table.insert().execute(data='row 1') nonai_table.insert().execute(data='row 2') assert False except exceptions.SQLError, e: print "Got exception", str(e) assert True - + nonai_table.insert().execute(id=1, data='row 1') finally: - nonai_table.drop() + nonai_table.drop() - def testwithautoincrement(self): - meta = MetaData(testbase.db) - table = Table("aitest", meta, - Column('id', Integer, primary_key=True), - Column('data', String(20))) - table.create() + # TODO: add coverage for increment on a secondary column in a key + def _test_autoincrement(self, bind): + ids = set() + rs = bind.execute(aitable.insert(), int1=1) + last = rs.last_inserted_ids()[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = bind.execute(aitable.insert(), str1='row 2') + last = rs.last_inserted_ids()[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = bind.execute(aitable.insert(), int1=3, str1='row 3') + last = rs.last_inserted_ids()[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = bind.execute(aitable.insert(values={'int1':func.length('four')})) + last = rs.last_inserted_ids()[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + self.assert_( + list(bind.execute(aitable.select().order_by(aitable.c.id))) == + [(1, 1, None), (2, None, 'row 2'), (3, 3, 'row 3'), (4, 4, None)]) + + def test_autoincrement_autocommit(self): + self._test_autoincrement(testing.db) + + def test_autoincrement_transaction(self): + con = testing.db.connect() + tx = con.begin() try: - table.insert().execute(data='row 1') - table.insert().execute(data='row 2') + try: + self._test_autoincrement(con) + except: + try: + tx.rollback() + except: + pass + raise + else: + tx.commit() finally: - table.drop() - - def testfetchid(self): - - # TODO: what does this test do that all the various ORM tests dont ? - - meta = MetaData(testbase.db) - table = Table("aitest", meta, - Column('id', Integer, primary_key=True), - Column('data', String(20))) - table.create() + con.close() + + def test_autoincrement_fk(self): + if not testing.db.dialect.supports_pk_autoincrement: + return True + + metadata = MetaData(testing.db) + # No optional sequence here. 
+ nodes = Table('nodes', metadata, + Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('nodes.id')), + Column('data', String(30))) + metadata.create_all() try: - # simulate working on a table that doesn't already exist - meta2 = MetaData(testbase.db) - table2 = Table("aitest", meta2, - Column('id', Integer, primary_key=True), - Column('data', String(20))) - class AiTest(object): - pass - mapper(AiTest, table2) - - s = create_session() - u = AiTest() - s.save(u) - s.flush() - assert u.id is not None - s.clear() + r = nodes.insert().execute(data='foo') + id_ = r.last_inserted_ids()[0] + nodes.insert().execute(data='bar', parent_id=id_) finally: - table.drop() - + metadata.drop_all() + + +class SequenceTest(TestBase): + __unsupported_on__ = ('sqlite', 'mysql', 'mssql', 'firebird', + 'sybase', 'access') -class SequenceTest(PersistTest): - @testing.supported('postgres', 'oracle') def setUpAll(self): global cartitems, sometable, metadata - metadata = MetaData(testbase.db) - cartitems = Table("cartitems", metadata, + metadata = MetaData(testing.db) + cartitems = Table("cartitems", metadata, Column("cart_id", Integer, Sequence('cart_id_seq'), primary_key=True), Column("description", String(40)), Column("createdate", DateTime()) ) sometable = Table( 'Manager', metadata, - Column( 'obj_id', Integer, Sequence('obj_id_seq'), ), - Column( 'name', String, ), - Column( 'id', Integer, primary_key= True, ), + Column('obj_id', Integer, Sequence('obj_id_seq'), ), + Column('name', String(128)), + Column('id', Integer, Sequence('Manager_id_seq', optional=True), + primary_key=True), ) - + metadata.create_all() - - @testing.supported('postgres', 'oracle') + def testseqnonpk(self): """test sequences fire off as defaults on non-pk columns""" - sometable.insert().execute(name="somename") - sometable.insert().execute(name="someother") + + result = sometable.insert().execute(name="somename") + assert 'id' in result.postfetch_cols() + + result = sometable.insert().execute(name="someother") + assert 'id' in result.postfetch_cols() + + sometable.insert().execute( + {'name':'name3'}, + {'name':'name4'} + ) assert sometable.select().execute().fetchall() == [ (1, "somename", 1), (2, "someother", 2), + (3, "name3", 3), + (4, "name4", 4), ] - - @testing.supported('postgres', 'oracle') + def testsequence(self): cartitems.insert().execute(description='hi') cartitems.insert().execute(description='there') - cartitems.insert().execute(description='lala') - + r = cartitems.insert().execute(description='lala') + + assert r.last_inserted_ids() and r.last_inserted_ids()[0] is not None + id_ = r.last_inserted_ids()[0] + + assert select([func.count(cartitems.c.cart_id)], + and_(cartitems.c.description == 'lala', + cartitems.c.cart_id == id_)).scalar() == 1 + cartitems.select().execute().fetchall() - - - @testing.supported('postgres', 'oracle') + + + @testing.fails_on('maxdb') + # maxdb db-api seems to double-execute NEXTVAL internally somewhere, + # throwing off the numbers for these tests... 
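        # The next few tests cover standalone Sequence execution: implicit
        # binding through a bound MetaData, explicit bind arguments to
        # create()/execute()/drop(), checkfirst create/drop, and executing the
        # sequence attached to cartitems.c.cart_id directly.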
def test_implicit_sequence_exec(self): - s = Sequence("my_sequence", metadata=MetaData(testbase.db)) + s = Sequence("my_sequence", metadata=MetaData(testing.db)) s.create() try: x = s.execute() @@ -285,32 +474,31 @@ class SequenceTest(PersistTest): finally: s.drop() - @testing.supported('postgres', 'oracle') + @testing.fails_on('maxdb') def teststandalone_explicit(self): s = Sequence("my_sequence") - s.create(bind=testbase.db) + s.create(bind=testing.db) try: - x = s.execute(testbase.db) + x = s.execute(testing.db) self.assert_(x == 1) finally: - s.drop(testbase.db) - - @testing.supported('postgres', 'oracle') + s.drop(testing.db) + def test_checkfirst(self): s = Sequence("my_sequence") - s.create(testbase.db, checkfirst=False) - s.create(testbase.db, checkfirst=True) - s.drop(testbase.db, checkfirst=False) - s.drop(testbase.db, checkfirst=True) - - @testing.supported('postgres', 'oracle') + s.create(testing.db, checkfirst=False) + s.create(testing.db, checkfirst=True) + s.drop(testing.db, checkfirst=False) + s.drop(testing.db, checkfirst=True) + + @testing.fails_on('maxdb') def teststandalone2(self): x = cartitems.c.cart_id.sequence.execute() self.assert_(1 <= x <= 4) - - @testing.supported('postgres', 'oracle') - def tearDownAll(self): + + def tearDownAll(self): metadata.drop_all() + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/sql/functions.py b/test/sql/functions.py new file mode 100644 index 0000000000..d1ce17c72f --- /dev/null +++ b/test/sql/functions.py @@ -0,0 +1,256 @@ +import testenv; testenv.configure_for_tests() +import datetime +from sqlalchemy import * +from sqlalchemy.sql import table, column +from sqlalchemy import databases, exceptions, sql, util +from sqlalchemy.sql.compiler import BIND_TEMPLATES +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes +from testlib import * + +from sqlalchemy.databases import * +# every dialect in databases.__all__ is expected to pass these tests. +dialects = [getattr(databases, mod).dialect() + for mod in databases.__all__ + # fixme! + if mod not in ('access',)] + +# if the configured dialect is out-of-tree or not yet in __all__, include it +# too. 
+if testing.db.name not in databases.__all__: + dialects.append(testing.db.dialect) + + +class CompileTest(TestBase, AssertsCompiledSQL): + def test_compile(self): + for dialect in dialects: + bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect) + self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect) + if isinstance(dialect, firebird.dialect): + self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect) + else: + self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect) + self.assert_compile(func.char_length('foo'), "char_length(%s)" % bindtemplate % {'name':'param_1', 'position':1}, dialect=dialect) + + def test_underscores(self): + self.assert_compile(func.if_(), "if()") + + def test_generic_now(self): + assert isinstance(func.now().type, sqltypes.DateTime) + + for ret, dialect in [ + ('CURRENT_TIMESTAMP', sqlite.dialect()), + ('now()', postgres.dialect()), + ('now()', mysql.dialect()), + ('CURRENT_TIMESTAMP', oracle.dialect()) + ]: + self.assert_compile(func.now(), ret, dialect=dialect) + + def test_generic_random(self): + assert func.random().type == sqltypes.NULLTYPE + assert isinstance(func.random(type_=Integer).type, Integer) + + for ret, dialect in [ + ('random()', sqlite.dialect()), + ('random()', postgres.dialect()), + ('rand()', mysql.dialect()), + ('random()', oracle.dialect()) + ]: + self.assert_compile(func.random(), ret, dialect=dialect) + + def test_constructor(self): + try: + func.current_timestamp('somearg') + assert False + except TypeError: + assert True + + try: + func.char_length('a', 'b') + assert False + except TypeError: + assert True + + try: + func.char_length() + assert False + except TypeError: + assert True + + def test_typing(self): + assert isinstance(func.coalesce(datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)).type, sqltypes.Date) + + assert isinstance(func.coalesce(None, datetime.date(2005, 10, 15)).type, sqltypes.Date) + + assert isinstance(func.concat("foo", "bar").type, sqltypes.String) + + def test_assorted(self): + table1 = table('mytable', + column('myid', Integer), + ) + + table2 = table( + 'myothertable', + column('otherid', Integer), + ) + + # test an expression with a function + self.assert_compile(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, + "lala(:lala_1, :lala_2, :param_1, mytable.myid) * myothertable.otherid") + + # test it in a SELECT + self.assert_compile(select([func.count(table1.c.myid)]), + "SELECT count(mytable.myid) AS count_1 FROM mytable") + + # test a "dotted" function name + self.assert_compile(select([func.foo.bar.lala(table1.c.myid)]), + "SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable") + + # test the bind parameter name with a "dotted" function name is only the name + # (limits the length of the bind param name) + self.assert_compile(select([func.foo.bar.lala(12)]), + "SELECT foo.bar.lala(:lala_2) AS lala_1") + + # test a dotted func off the engine itself + self.assert_compile(func.lala.hoho(7), "lala.hoho(:hoho_1)") + + # test None becomes NULL + self.assert_compile(func.my_func(1,2,None,3), "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") + + # test pickling + self.assert_compile(util.pickle.loads(util.pickle.dumps(func.my_func(1, 2, None, 3))), "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") + + # assert func raises AttributeError for __bases__ attribute, since its not a class + # fixes pydoc + try: + func.__bases__ + assert False + 
except AttributeError: + assert True + + def test_functions_with_cols(self): + users = table('users', column('id'), column('name'), column('fullname')) + calculate = select([column('q'), column('z'), column('r')], + from_obj=[func.calculate(bindparam('x'), bindparam('y'))]) + + self.assert_compile(select([users], users.c.id > calculate.c.z), + "SELECT users.id, users.name, users.fullname " + "FROM users, (SELECT q, z, r " + "FROM calculate(:x, :y)) " + "WHERE users.id > z" + ) + + s = select([users], users.c.id.between( + calculate.alias('c1').unique_params(x=17, y=45).c.z, + calculate.alias('c2').unique_params(x=5, y=12).c.z)) + + self.assert_compile(s, + "SELECT users.id, users.name, users.fullname " + "FROM users, (SELECT q, z, r " + "FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r " + "FROM calculate(:x_2, :y_2)) AS c2 " + "WHERE users.id BETWEEN c1.z AND c2.z" + , checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5}) + + +class ExecuteTest(TestBase): + + def test_standalone_execute(self): + x = testing.db.func.current_date().execute().scalar() + y = testing.db.func.current_date().select().execute().scalar() + z = testing.db.func.current_date().scalar() + assert (x == y == z) is True + + # ansi func + x = testing.db.func.current_date() + assert isinstance(x.type, Date) + assert isinstance(x.execute().scalar(), datetime.date) + + def test_conn_execute(self): + conn = testing.db.connect() + try: + x = conn.execute(func.current_date()).scalar() + y = conn.execute(func.current_date().select()).scalar() + z = conn.scalar(func.current_date()) + finally: + conn.close() + assert (x == y == z) is True + + def test_update(self): + """ + Tests sending functions and SQL expressions to the VALUES and SET + clauses of INSERT/UPDATE instances, and that column-level defaults + get overridden. 
+ """ + + meta = MetaData(testing.db) + t = Table('t1', meta, + Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True), + Column('value', Integer) + ) + t2 = Table('t2', meta, + Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True), + Column('value', Integer, default=7), + Column('stuff', String(20), onupdate="thisisstuff") + ) + meta.create_all() + try: + t.insert(values=dict(value=func.length("one"))).execute() + assert t.select().execute().fetchone()['value'] == 3 + t.update(values=dict(value=func.length("asfda"))).execute() + assert t.select().execute().fetchone()['value'] == 5 + + r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() + id = r.last_inserted_ids()[0] + assert t.select(t.c.id==id).execute().fetchone()['value'] == 9 + t.update(values={t.c.value:func.length("asdf")}).execute() + assert t.select().execute().fetchone()['value'] == 4 + print "--------------------------" + t2.insert().execute() + t2.insert(values=dict(value=func.length("one"))).execute() + t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi") + + res = exec_sorted(select([t2.c.value, t2.c.stuff])) + self.assertEquals(res, [(-14, 'hi'), (3, None), (7, None)]) + + t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff") + assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] + + t2.delete().execute() + + t2.insert(values=dict(value=func.length("one") + 8)).execute() + assert t2.select().execute().fetchone()['value'] == 11 + + t2.update(values=dict(value=func.length("asfda"))).execute() + assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (5, "thisisstuff") + + t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute() + print "HI", select([t2.c.value, t2.c.stuff]).execute().fetchone() + assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (9, "foo") + finally: + meta.drop_all() + + @testing.fails_on_everything_except('postgres') + def test_as_from(self): + # TODO: shouldnt this work on oracle too ? 
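        # For illustration of the FROM-object usage exercised below, the same
        # pattern with a genuinely set-returning function would read (assuming
        # postgres and its built-in generate_series(), not used by this test):
        #
        #   gs = select([sql.column('generate_series', type_=Integer)],
        #               from_obj=[testing.db.func.generate_series(1, 5)])
        #   assert len(gs.execute().fetchall()) == 5
        #
        # i.e. the function supplies the FROM clause and its output column is
        # addressable like any other column, which is the point of [ticket:172].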
+ x = testing.db.func.current_date().execute().scalar() + y = testing.db.func.current_date().select().execute().scalar() + z = testing.db.func.current_date().scalar() + w = select(['*'], from_obj=[testing.db.func.current_date()]).scalar() + + # construct a column-based FROM object out of a function, like in [ticket:172] + s = select([sql.column('date', type_=DateTime)], from_obj=[testing.db.func.current_date()]) + q = s.execute().fetchone()[s.c.date] + r = s.alias('datequery').select().scalar() + + assert x == y == z == w == q == r + +def exec_sorted(statement, *args, **kw): + """Executes a statement and returns a sorted list plain tuple rows.""" + + return sorted([tuple(row) + for row in statement.execute(*args, **kw).fetchall()]) + +if __name__ == '__main__': + testenv.main() diff --git a/test/sql/generative.py b/test/sql/generative.py index 357a66fcdf..8204742821 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -1,15 +1,20 @@ -import testbase -from sql import select as selecttests +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy.sql import table, column, ClauseElement +from sqlalchemy.sql.expression import _clone from testlib import * +from sqlalchemy.sql.visitors import * +from sqlalchemy import util +from sqlalchemy.sql import util as sql_util -class TraversalTest(AssertMixin): + +class TraversalTest(TestBase, AssertsExecutionResults): """test ClauseVisitor's traversal, particularly its ability to copy and modify a ClauseElement in place.""" - + def setUpAll(self): global A, B - + # establish two ficticious ClauseElements. # define deep equality semantics as well as deep identity semantics. class A(ClauseElement): @@ -18,16 +23,16 @@ class TraversalTest(AssertMixin): def is_other(self, other): return other is self - + def __eq__(self, other): return other.expr == self.expr - + def __ne__(self, other): return other.expr != self.expr - + def __str__(self): return "A(%s)" % repr(self.expr) - + class B(ClauseElement): def __init__(self, *items): self.items = items @@ -45,22 +50,22 @@ class TraversalTest(AssertMixin): if i1 != i2: return False return True - + def __ne__(self, other): for i1, i2 in zip(self.items, other.items): if i1 != i2: return True return False - - def _copy_internals(self): - self.items = [i._clone() for i in self.items] + + def _copy_internals(self, clone=_clone): + self.items = [clone(i) for i in self.items] def get_children(self, **kwargs): return self.items - + def __str__(self): return "B(%s)" % repr([str(i) for i in self.items]) - + def test_test_classes(self): a1 = A("expr1") struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) @@ -73,22 +78,22 @@ class TraversalTest(AssertMixin): assert struct != struct3 assert not struct.is_other(struct2) assert not struct.is_other(struct3) - - def test_clone(self): + + def test_clone(self): struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) - + class Vis(ClauseVisitor): def visit_a(self, a): pass def visit_b(self, b): pass - + vis = Vis() s2 = vis.traverse(struct, clone=True) assert struct == s2 assert not struct.is_other(s2) - - def test_no_clone(self): + + def test_no_clone(self): struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) class Vis(ClauseVisitor): @@ -101,7 +106,7 @@ class TraversalTest(AssertMixin): s2 = vis.traverse(struct, clone=False) assert struct == s2 assert struct.is_other(s2) - + def test_change_in_place(self): struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), 
A("expr3")) struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3")) @@ -132,44 +137,72 @@ class TraversalTest(AssertMixin): assert struct != s3 assert struct3 == s3 -class ClauseTest(selecttests.SQLTest): + +class ClauseTest(TestBase, AssertsCompiledSQL): """test copy-in-place behavior of various ClauseElements.""" - + def setUpAll(self): global t1, t2 - t1 = table("table1", + t1 = table("table1", column("col1"), column("col2"), column("col3"), ) - t2 = table("table2", + t2 = table("table2", column("col1"), column("col2"), column("col3"), ) - + def test_binary(self): clause = t1.c.col2 == t2.c.col2 assert str(clause) == ClauseVisitor().traverse(clause, clone=True) - + + def test_binary_anon_label_quirk(self): + t = table('t1', column('col1')) + + + f = t.c.col1 * 5 + self.assert_compile(select([f]), "SELECT t1.col1 * :col1_1 AS anon_1 FROM t1") + + f.anon_label + + a = t.alias() + f = sql_util.ClauseAdapter(a).traverse(f) + + self.assert_compile(select([f]), "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1") + def test_join(self): clause = t1.join(t2, t1.c.col2==t2.c.col2) c1 = str(clause) assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True)) - + class Vis(ClauseVisitor): def visit_binary(self, binary): binary.right = t2.c.col3 - + clause2 = Vis().traverse(clause, clone=True) assert c1 == str(clause) assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3)) + def test_text(self): + clause = text("select * from table where foo=:bar", bindparams=[bindparam('bar')]) + c1 = str(clause) + class Vis(ClauseVisitor): + def visit_textclause(self, text): + text.text = text.text + " SOME MODIFIER=:lala" + text.bindparams['lala'] = bindparam('lala') + + clause2 = Vis().traverse(clause, clone=True) + assert c1 == str(clause) + assert str(clause2) == c1 + " SOME MODIFIER=:lala" + assert clause.bindparams.keys() == ['bar'] + assert util.Set(clause2.bindparams.keys()) == util.Set(['bar', 'lala']) + def test_select(self): - s = t1.select() - s2 = select([s]) + s2 = select([t1]) s2_assert = str(s2) - s3_assert = str(select([t1.select()], t1.c.col2==7)) + s3_assert = str(select([t1], t1.c.col2==7)) class Vis(ClauseVisitor): def visit_select(self, select): select.append_whereclause(t1.c.col2==7) @@ -182,8 +215,8 @@ class ClauseTest(selecttests.SQLTest): assert str(s2) == s3_assert print "------------------" - - s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9))) + + s4_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col3==9))) class Vis(ClauseVisitor): def visit_select(self, select): select.append_whereclause(t1.c.col3==9) @@ -192,84 +225,420 @@ class ClauseTest(selecttests.SQLTest): print str(s4) assert str(s4) == s4_assert assert str(s3) == s3_assert - + print "------------------" - s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9))) + s5_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col1==9))) class Vis(ClauseVisitor): def visit_binary(self, binary): if binary.left is t1.c.col3: binary.left = t1.c.col1 - binary.right = bindparam("table1_col1") + binary.right = bindparam("col1", unique=True) s5 = Vis().traverse(s4, clone=True) print str(s4) print str(s5) assert str(s5) == s5_assert assert str(s4) == s4_assert + def test_union(self): + u = union(t1.select(), t2.select()) + u2 = ClauseVisitor().traverse(u, clone=True) + assert str(u) == str(u2) + assert [str(c) for c in u2.c] == [str(c) for c in u.c] + + u = union(t1.select(), t2.select()) + cols = [str(c) for c in u.c] + u2 = ClauseVisitor().traverse(u, 
clone=True) + assert str(u) == str(u2) + assert [str(c) for c in u2.c] == cols + + s1 = select([t1], t1.c.col1 == bindparam('id_param')) + s2 = select([t2]) + u = union(s1, s2) + + u2 = u.params(id_param=7) + u3 = u.params(id_param=10) + assert str(u) == str(u2) == str(u3) + assert u2.compile().params == {'id_param':7} + assert u3.compile().params == {'id_param':10} + + def test_binds(self): + """test that unique bindparams change their name upon clone() to prevent conflicts""" + + s = select([t1], t1.c.col1==bindparam(None, unique=True)).alias() + s2 = ClauseVisitor().traverse(s, clone=True).alias() + s3 = select([s], s.c.col2==s2.c.col2) + + self.assert_compile(s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "\ + "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :param_1) AS anon_1, "\ + "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :param_2) AS anon_2 "\ + "WHERE anon_1.col2 = anon_2.col2") + + s = select([t1], t1.c.col1==4).alias() + s2 = ClauseVisitor().traverse(s, clone=True).alias() + s3 = select([s], s.c.col2==s2.c.col2) + self.assert_compile(s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "\ + "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1, "\ + "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_2) AS anon_2 "\ + "WHERE anon_1.col2 = anon_2.col2") + + @testing.emits_warning('.*replaced by another column with the same key') + def test_alias(self): + subq = t2.select().alias('subq') + s = select([t1.c.col1, subq.c.col1], from_obj=[t1, subq, t1.join(subq, t1.c.col1==subq.c.col2)]) + orig = str(s) + s2 = ClauseVisitor().traverse(s, clone=True) + assert orig == str(s) == str(s2) + + s4 = ClauseVisitor().traverse(s2, clone=True) + assert orig == str(s) == str(s2) == str(s4) + + s3 = sql_util.ClauseAdapter(table('foo')).traverse(s, clone=True) + assert orig == str(s) == str(s3) + + s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3, clone=True) + assert orig == str(s) == str(s3) == str(s4) + def test_correlated_select(self): s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2) class Vis(ClauseVisitor): def visit_select(self, select): select.append_whereclause(t1.c.col2==7) - - self.runtest(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2") - def test_clause_adapter(self): - from sqlalchemy import sql_util - + self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :col2_1") + +class ClauseAdapterTest(TestBase, AssertsCompiledSQL): + def setUpAll(self): + global t1, t2 + t1 = table("table1", + column("col1"), + column("col2"), + column("col3"), + ) + t2 = table("table2", + column("col1"), + column("col2"), + column("col3"), + ) + + def test_correlation_on_clone(self): t1alias = t1.alias('t1alias') + t2alias = t2.alias('t2alias') + vis = sql_util.ClauseAdapter(t1alias) + + s = select(['*'], from_obj=[t1alias, t2alias]).as_scalar() + assert t2alias in s._froms + assert t1alias in s._froms + + self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)") + s = vis.traverse(s, clone=True) + assert t2alias not in s._froms # not present because it's been cloned + assert t1alias in s._froms # present because 
the adapter placed it there + # correlate list on "s" needs to take into account the full _cloned_set for each element in _froms when correlating + self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)") + + s = select(['*'], from_obj=[t1alias, t2alias]).correlate(t2alias).as_scalar() + self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)") + s = vis.traverse(s, clone=True) + self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)") + s = ClauseVisitor().traverse(s, clone=True) + self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)") + s = select(['*']).where(t1.c.col1==t2.c.col1).as_scalar() + self.assert_compile(select([t1.c.col1, s]), "SELECT table1.col1, (SELECT * FROM table2 WHERE table1.col1 = table2.col1) AS anon_1 FROM table1") vis = sql_util.ClauseAdapter(t1alias) - ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True) - assert ff._get_from_objects() == [t1alias] + s = vis.traverse(s, clone=True) + self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias") + s = ClauseVisitor().traverse(s, clone=True) + self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias") + + s = select(['*']).where(t1.c.col1==t2.c.col1).correlate(t1).as_scalar() + self.assert_compile(select([t1.c.col1, s]), "SELECT table1.col1, (SELECT * FROM table2 WHERE table1.col1 = table2.col1) AS anon_1 FROM table1") + vis = sql_util.ClauseAdapter(t1alias) + s = vis.traverse(s, clone=True) + self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias") + s = ClauseVisitor().traverse(s, clone=True) + self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias") - self.runtest(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias") - self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2") - self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2") - self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2") - self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2") + def test_table_to_alias(self): + + t1alias = t1.alias('t1alias') + + vis = sql_util.ClauseAdapter(t1alias) ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True) - self.runtest(ff, "count(t1alias.col1) AS foo") assert ff._get_from_objects() == [t1alias] - + + self.assert_compile(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias") + self.assert_compile(vis.traverse(select(['*'], 
t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2") + self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2") + self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2") + self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2") + + + s = select(['*'], from_obj=[t1]).alias('foo') + self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo") + self.assert_compile(vis.traverse(s.select(), clone=True), "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo") + self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo") + + ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True) + self.assert_compile(ff, "count(t1alias.col1) AS foo") + assert ff._get_from_objects() == [t1alias] + # TODO: -# self.runtest(vis.traverse(select([func.count(t1.c.col1).label('foo')]), clone=True), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias") - + # self.assert_compile(vis.traverse(select([func.count(t1.c.col1).label('foo')]), clone=True), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias") + t2alias = t2.alias('t2alias') vis.chain(sql_util.ClauseAdapter(t2alias)) - self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") - self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") - self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") - self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2") - - + self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2") + + def test_include_exclude(self): + m = MetaData() + a=Table( 'a',m, + Column( 'id', Integer, primary_key=True), + Column( 'xxx_id', Integer, ForeignKey( 'a.id', name='adf',use_alter=True ) ) + ) + + e = (a.c.id == a.c.xxx_id) + assert str(e) == "a.id = a.xxx_id" + b = a.alias() + + e = sql_util.ClauseAdapter( b, include= set([ a.c.id ]), + equivalents= { a.c.id: set([ a.c.id]) } + ).traverse( e) + + assert str(e) == "a_1.id = a.xxx_id" + + def test_join_to_alias(self): + metadata = MetaData() + a = Table('a', metadata, + 
Column('id', Integer, primary_key=True)) + b = Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + c = Table('c', metadata, + Column('id', Integer, primary_key=True), + Column('bid', Integer, ForeignKey('b.id')), + ) + + d = Table('d', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + + j1 = a.outerjoin(b) + j2 = select([j1], use_labels=True) + + j3 = c.join(j2, j2.c.b_id==c.c.bid) + + j4 = j3.outerjoin(d) + self.assert_compile(j4, "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) " + "ON b_id = c.bid" + " LEFT OUTER JOIN d ON a_id = d.aid") + j5 = j3.alias('foo') + j6 = sql_util.ClauseAdapter(j5).copy_and_process([j4])[0] + + # this statement takes c join(a join b), wraps it inside an aliased "select * from c join(a join b) AS foo". + # the outermost right side "left outer join d" stays the same, except "d" joins against foo.a_id instead + # of plain "a_id" + self.assert_compile(j6, "(SELECT c.id AS c_id, c.bid AS c_bid, a_id AS a_id, b_id AS b_id, b_aid AS b_aid FROM " + "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) " + "ON b_id = c.bid) AS foo" + " LEFT OUTER JOIN d ON foo.a_id = d.aid") + + def test_derived_from(self): + assert select([t1]).is_derived_from(t1) + assert not select([t2]).is_derived_from(t1) + assert not t1.is_derived_from(select([t1])) + assert t1.alias().is_derived_from(t1) + + + s1 = select([t1, t2]).alias('foo') + s2 = select([s1]).limit(5).offset(10).alias() + assert s2.is_derived_from(s1) + s2 = s2._clone() + assert s2.is_derived_from(s1) + + def test_aliasedselect_to_aliasedselect(self): + # original issue from ticket #904 + s1 = select([t1]).alias('foo') + s2 = select([s1]).limit(5).offset(10).alias() + + self.assert_compile(sql_util.ClauseAdapter(s2).traverse(s1), + "SELECT foo.col1, foo.col2, foo.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo LIMIT 5 OFFSET 10") + + j = s1.outerjoin(t2, s1.c.col1==t2.c.col1) + self.assert_compile(sql_util.ClauseAdapter(s2).traverse(j).select(), + "SELECT anon_1.col1, anon_1.col2, anon_1.col3, table2.col1, table2.col2, table2.col3 FROM "\ + "(SELECT foo.col1 AS col1, foo.col2 AS col2, foo.col3 AS col3 FROM "\ + "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo LIMIT 5 OFFSET 10) AS anon_1 "\ + "LEFT OUTER JOIN table2 ON anon_1.col1 = table2.col1") + + talias = t1.alias('bar') + j = s1.outerjoin(talias, s1.c.col1==talias.c.col1) + self.assert_compile(sql_util.ClauseAdapter(s2).traverse(j).select(), + "SELECT anon_1.col1, anon_1.col2, anon_1.col3, bar.col1, bar.col2, bar.col3 FROM "\ + "(SELECT foo.col1 AS col1, foo.col2 AS col2, foo.col3 AS col3 FROM "\ + "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo LIMIT 5 OFFSET 10) AS anon_1 "\ + "LEFT OUTER JOIN table1 AS bar ON anon_1.col1 = bar.col1") + + def test_recursive(self): + metadata = MetaData() + a = Table('a', metadata, + Column('id', Integer, primary_key=True)) + b = Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + c = Table('c', metadata, + Column('id', Integer, primary_key=True), + Column('bid', Integer, ForeignKey('b.id')), + ) + + d = Table('d', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + + u 
= union( + a.join(b).select().apply_labels(), + a.join(d).select().apply_labels() + ).alias() -class SelectTest(selecttests.SQLTest): + self.assert_compile( + sql_util.ClauseAdapter(u).traverse(select([c.c.bid]).where(c.c.bid==u.c.b_aid)), + "SELECT c.bid "\ + "FROM c, (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid "\ + "FROM a JOIN b ON a.id = b.aid UNION SELECT a.id AS a_id, d.id AS d_id, d.aid AS d_aid "\ + "FROM a JOIN d ON a.id = d.aid) AS anon_1 "\ + "WHERE c.bid = anon_1.b_aid" + ) + +class SelectTest(TestBase, AssertsCompiledSQL): """tests the generative capability of Select""" def setUpAll(self): global t1, t2 - t1 = table("table1", + t1 = table("table1", column("col1"), column("col2"), column("col3"), ) - t2 = table("table2", + t2 = table("table2", column("col1"), column("col2"), column("col3"), ) - + def test_select(self): - self.runtest(t1.select().where(t1.c.col1==5).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1 WHERE table1.col1 = :table1_col1 ORDER BY table1.col3") - - self.runtest(t1.select().select_from(select([t2], t2.c.col1==t1.c.col1)).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1) ORDER BY table1.col3") - + self.assert_compile(t1.select().where(t1.c.col1==5).order_by(t1.c.col3), + "SELECT table1.col1, table1.col2, table1.col3 FROM table1 WHERE table1.col1 = :col1_1 ORDER BY table1.col3") + + self.assert_compile(t1.select().select_from(select([t2], t2.c.col1==t1.c.col1)).order_by(t1.c.col3), + "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 "\ + "FROM table2 WHERE table2.col1 = table1.col1) ORDER BY table1.col3") + s = select([t2], t2.c.col1==t1.c.col1, correlate=False) s = s.correlate(t1).order_by(t2.c.col3) - self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3") + self.assert_compile(t1.select().select_from(s).order_by(t1.c.col3), + "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 "\ + "FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3") + + def test_columns(self): + s = t1.select() + self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1") + select_copy = s.column('yyy') + self.assert_compile(select_copy, "SELECT table1.col1, table1.col2, table1.col3, yyy FROM table1") + assert s.columns is not select_copy.columns + assert s._columns is not select_copy._columns + assert s._raw_columns is not select_copy._raw_columns + self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1") + + def test_froms(self): + s = t1.select() + self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1") + select_copy = s.select_from(t2) + self.assert_compile(select_copy, "SELECT table1.col1, table1.col2, table1.col3 FROM table1, table2") + assert s._froms is not select_copy._froms + self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1") + + def test_correlation(self): + s = select([t2], t1.c.col1==t2.c.col1) + self.assert_compile(s, "SELECT table2.col1, table2.col2, table2.col3 FROM table2, 
table1 WHERE table1.col1 = table2.col1") + s2 = select([t1], t1.c.col2==s.c.col2) + self.assert_compile(s2, "SELECT table1.col1, table1.col2, table1.col3 FROM table1, " + "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 " + "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2") + + s3 = s.correlate(None) + self.assert_compile(select([t1], t1.c.col2==s3.c.col2), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, " + "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2, table1 " + "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2") + self.assert_compile(select([t1], t1.c.col2==s.c.col2), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, " + "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 " + "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2") + s4 = s3.correlate(t1) + self.assert_compile(select([t1], t1.c.col2==s4.c.col2), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, " + "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 " + "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2") + self.assert_compile(select([t1], t1.c.col2==s3.c.col2), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, " + "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2, table1 " + "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2") + + def test_prefixes(self): + s = t1.select() + self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1") + select_copy = s.prefix_with("FOOBER") + self.assert_compile(select_copy, "SELECT FOOBER table1.col1, table1.col2, table1.col3 FROM table1") + self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1") + + +class InsertTest(TestBase, AssertsCompiledSQL): + """Tests the generative capability of Insert""" + + # fixme: consolidate converage from elsewhere here and expand + + def setUpAll(self): + global t1, t2 + t1 = table("table1", + column("col1"), + column("col2"), + column("col3"), + ) + t2 = table("table2", + column("col1"), + column("col2"), + column("col3"), + ) + + def test_prefixes(self): + i = t1.insert() + self.assert_compile(i, + "INSERT INTO table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)") + + gen = i.prefix_with("foober") + self.assert_compile(gen, + "INSERT foober INTO table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)") + + self.assert_compile(i, + "INSERT INTO table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)") + + i2 = t1.insert(prefixes=['squiznart']) + self.assert_compile(i2, + "INSERT squiznart INTO table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)") + gen2 = i2.prefix_with("quux") + self.assert_compile(gen2, + "INSERT squiznart quux INTO " + "table1 (col1, col2, col3) " + "VALUES (:col1, :col2, :col3)") if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/sql/labels.py b/test/sql/labels.py index 553a3a3bc3..cbcd4636eb 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -1,48 +1,51 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from testlib import * - +from sqlalchemy.engine import default # TODO: either create a mock dialect with named paramstyle and a short identifier length, # or find a way to just use sqlite dialect and make those changes -class LabelTypeTest(PersistTest): +IDENT_LENGTH = 29 + +class LabelTypeTest(TestBase): def test_type(self): m = MetaData() - 
t = Table('sometable', m, + t = Table('sometable', m, Column('col1', Integer), Column('col2', Float)) assert isinstance(t.c.col1.label('hi').type, Integer) - assert isinstance(select([t.c.col2], scalar=True).label('lala').type, Float) + assert isinstance(select([t.c.col2]).as_scalar().label('lala').type, Float) -class LongLabelsTest(PersistTest): +class LongLabelsTest(TestBase, AssertsCompiledSQL): def setUpAll(self): global metadata, table1, maxlen - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) table1 = Table("some_large_named_table", metadata, Column("this_is_the_primarykey_column", Integer, Sequence("this_is_some_large_seq"), primary_key=True), Column("this_is_the_data_column", String(30)) ) - + metadata.create_all() - - maxlen = testbase.db.dialect.max_identifier_length - testbase.db.dialect.max_identifier_length = lambda: 29 - + + maxlen = testing.db.dialect.max_identifier_length + testing.db.dialect.max_identifier_length = IDENT_LENGTH + def tearDown(self): table1.delete().execute() - + def tearDownAll(self): metadata.drop_all() - testbase.db.dialect.max_identifier_length = maxlen - + testing.db.dialect.max_identifier_length = maxlen + def test_result(self): table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"}) table1.insert().execute(**{"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"}) table1.insert().execute(**{"this_is_the_primarykey_column":3, "this_is_the_data_column":"data3"}) table1.insert().execute(**{"this_is_the_primarykey_column":4, "this_is_the_data_column":"data4"}) - r = table1.select(use_labels=True, order_by=[table1.c.this_is_the_primarykey_column]).execute() + s = table1.select(use_labels=True, order_by=[table1.c.this_is_the_primarykey_column]) + r = s.execute() result = [] for row in r: result.append((row[table1.c.this_is_the_primarykey_column], row[table1.c.this_is_the_data_column])) @@ -52,7 +55,29 @@ class LongLabelsTest(PersistTest): (3, "data3"), (4, "data4"), ], repr(result) - + + # some dialects such as oracle (and possibly ms-sql in a future version) + # generate a subquery for limits/offsets. 
+ # ensure that the generated result map corresponds to the selected table, not + # the select query + r = s.limit(2).execute() + result = [] + for row in r: + result.append((row[table1.c.this_is_the_primarykey_column], row[table1.c.this_is_the_data_column])) + assert result == [ + (1, "data1"), + (2, "data2"), + ], repr(result) + + r = s.limit(2).offset(1).execute() + result = [] + for row in r: + result.append((row[table1.c.this_is_the_primarykey_column], row[table1.c.this_is_the_data_column])) + assert result == [ + (2, "data2"), + (3, "data3"), + ], repr(result) + def test_colbinds(self): table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"}) table1.insert().execute(**{"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"}) @@ -67,29 +92,43 @@ class LongLabelsTest(PersistTest): table1.c.this_is_the_primarykey_column == 2 )).execute() assert r.fetchall() == [(2, "data2"), (4, "data4")] - + def test_insert_no_pk(self): table1.insert().execute(**{"this_is_the_data_column":"data1"}) table1.insert().execute(**{"this_is_the_data_column":"data2"}) table1.insert().execute(**{"this_is_the_data_column":"data3"}) table1.insert().execute(**{"this_is_the_data_column":"data4"}) - + def test_subquery(self): - # this is the test that fails if the "max identifier length" is shorter than the + # this is the test that fails if the "max identifier length" is shorter than the # length of the actual columns created, because the column names get truncated. # if you try to separate "physical columns" from "labels", and only truncate the labels, - # the ansisql.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code, + # the compiler.DefaultCompiler.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code, # since it is creating "labels" on the fly but not affecting derived columns, which think they are # still "physical" q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias('foo') x = select([q]) print x.execute().fetchall() - + + def test_anon_alias(self): + compile_dialect = default.DefaultDialect() + compile_dialect.max_identifier_length = IDENT_LENGTH + + q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias() + x = select([q], use_labels=True) + + self.assert_compile(x, "SELECT anon_1.this_is_the_primarykey_column AS anon_1_this_is_the_prim_1, anon_1.this_is_the_data_column AS anon_1_this_is_the_data_2 " + "FROM (SELECT some_large_named_table.this_is_the_primarykey_column AS this_is_the_primarykey_column, some_large_named_table.this_is_the_data_column AS this_is_the_data_column " + "FROM some_large_named_table " + "WHERE some_large_named_table.this_is_the_primarykey_column = :this_is_the_primarykey__1) AS anon_1", dialect=compile_dialect) + + print x.execute().fetchall() + def test_oid(self): """test that a primary key column compiled as the 'oid' column gets proper length truncation""" from sqlalchemy.databases import postgres dialect = postgres.PGDialect() - dialect.max_identifier_length = lambda: 30 + dialect.max_identifier_length = 30 tt = table1.select(use_labels=True).alias('foo') x = select([tt], use_labels=True, order_by=tt.oid_column).compile(dialect=dialect) #print x @@ -97,4 +136,4 @@ class LongLabelsTest(PersistTest): assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_2""") if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/sql/query.py 
b/test/sql/query.py index 48a28a9a56..e6d6714c2c 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -1,38 +1,46 @@ -import testbase +import testenv; testenv.configure_for_tests() import datetime from sqlalchemy import * -from sqlalchemy import exceptions +from sqlalchemy import exceptions, sql +from sqlalchemy.engine import default from testlib import * -class QueryTest(PersistTest): - +class QueryTest(TestBase): + def setUpAll(self): global users, addresses, metadata - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) users = Table('query_users', metadata, Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), ) - addresses = Table('query_addresses', metadata, + addresses = Table('query_addresses', metadata, Column('address_id', Integer, primary_key=True), Column('user_id', Integer, ForeignKey('query_users.user_id')), Column('address', String(30))) metadata.create_all() - + def tearDown(self): addresses.delete().execute() users.delete().execute() - + def tearDownAll(self): metadata.drop_all() - - def testinsert(self): + + def test_insert(self): users.insert().execute(user_id = 7, user_name = 'jack') assert users.count().scalar() == 1 - - def testupdate(self): + def test_insert_heterogeneous_params(self): + users.insert().execute( + {'user_id':7, 'user_name':'jack'}, + {'user_id':8, 'user_name':'ed'}, + {'user_id':9} + ) + assert users.select().execute().fetchall() == [(7, 'jack'), (8, 'ed'), (9, None)] + + def test_update(self): users.insert().execute(user_id = 7, user_name = 'jack') assert users.count().scalar() == 1 @@ -40,15 +48,15 @@ class QueryTest(PersistTest): assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred' def test_lastrow_accessor(self): - """test the last_inserted_ids() and lastrow_has_id() functions""" + """Tests the last_inserted_ids() and lastrow_has_id() functions.""" def insert_values(table, values): - """insert a row into a table, return the full list of values INSERTed including defaults - that fired off on the DB side. - + """ + Inserts a row into a table, returns the full list of values + INSERTed including defaults that fired off on the DB side and detects rows that had defaults and post-fetches. 
""" - + result = table.insert().execute(**values) ret = values.copy() @@ -65,7 +73,7 @@ class QueryTest(PersistTest): for supported, table, values, assertvalues in [ ( {'unsupported':['sqlite']}, - Table("t1", metadata, + Table("t1", metadata, Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True), Column('foo', String(30), primary_key=True)), {'foo':'hi'}, @@ -73,7 +81,7 @@ class QueryTest(PersistTest): ), ( {'unsupported':['sqlite']}, - Table("t2", metadata, + Table("t2", metadata, Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True), Column('foo', String(30), primary_key=True), Column('bar', String(30), PassiveDefault('hi')) @@ -83,7 +91,7 @@ class QueryTest(PersistTest): ), ( {'unsupported':[]}, - Table("t3", metadata, + Table("t3", metadata, Column("id", String(40), primary_key=True), Column('foo', String(30), primary_key=True), Column("bar", String(30)) @@ -93,7 +101,7 @@ class QueryTest(PersistTest): ), ( {'unsupported':[]}, - Table("t4", metadata, + Table("t4", metadata, Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True), Column('foo', String(30), primary_key=True), Column('bar', String(30), PassiveDefault('hi')) @@ -103,7 +111,7 @@ class QueryTest(PersistTest): ), ( {'unsupported':[]}, - Table("t5", metadata, + Table("t5", metadata, Column('id', String(10), primary_key=True), Column('bar', String(30), PassiveDefault('hi')) ), @@ -111,51 +119,121 @@ class QueryTest(PersistTest): {'id':'id1', 'bar':'hi'}, ), ]: - if testbase.db.name in supported['unsupported']: + if testing.db.name in supported['unsupported']: continue try: table.create() - assert insert_values(table, values) == assertvalues, repr(values) + " " + repr(assertvalues) + i = insert_values(table, values) + assert i == assertvalues, repr(i) + " " + repr(assertvalues) finally: table.drop() - def testrowiteration(self): - users.insert().execute(user_id = 7, user_name = 'jack') - users.insert().execute(user_id = 8, user_name = 'ed') - users.insert().execute(user_id = 9, user_name = 'fred') + def test_row_iteration(self): + users.insert().execute( + {'user_id':7, 'user_name':'jack'}, + {'user_id':8, 'user_name':'ed'}, + {'user_id':9, 'user_name':'fred'}, + ) r = users.select().execute() l = [] for row in r: l.append(row) self.assert_(len(l) == 3) + def test_anonymous_rows(self): + users.insert().execute( + {'user_id':7, 'user_name':'jack'}, + {'user_id':8, 'user_name':'ed'}, + {'user_id':9, 'user_name':'fred'}, + ) + + sel = select([users.c.user_id]).where(users.c.user_name=='jack').as_scalar() + for row in select([sel + 1, sel + 3], bind=users.bind).execute(): + assert row['anon_1'] == 8 + assert row['anon_2'] == 10 + + def test_row_comparison(self): + users.insert().execute(user_id = 7, user_name = 'jack') + rp = users.select().execute().fetchone() + + self.assert_(rp == rp) + self.assert_(not(rp != rp)) + + equal = (7, 'jack') + + self.assert_(rp == equal) + self.assert_(equal == rp) + self.assert_(not (rp != equal)) + self.assert_(not (equal != equal)) + def test_fetchmany(self): - users.insert().execute(user_id = 7, user_name = 'jack') - users.insert().execute(user_id = 8, user_name = 'ed') - users.insert().execute(user_id = 9, user_name = 'fred') - r = users.select().execute() - l = [] - for row in r.fetchmany(size=2): - l.append(row) + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'ed') + users.insert().execute(user_id = 9, user_name = 'fred') + r = users.select().execute() + l = [] + 
for row in r.fetchmany(size=2): + l.append(row) self.assert_(len(l) == 2, "fetchmany(size=2) got %s rows" % len(l)) - + + def test_ilike(self): + users.insert().execute( + {'user_id':1, 'user_name':'one'}, + {'user_id':2, 'user_name':'TwO'}, + {'user_id':3, 'user_name':'ONE'}, + {'user_id':4, 'user_name':'OnE'}, + ) + + self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('one')).execute().fetchall(), [(1, ), (3, ), (4, )]) + + self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )]) + + if testing.against('postgres'): + self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )]) + self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), []) + + def test_compiled_execute(self): - users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 7, user_name = 'jack') + s = select([users], users.c.user_id==bindparam('id')).compile() + c = testing.db.connect() + assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7 + + def test_compiled_insert_execute(self): + users.insert().compile().execute(user_id = 7, user_name = 'jack') s = select([users], users.c.user_id==bindparam('id')).compile() - c = testbase.db.connect() + c = testing.db.connect() assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7 def test_repeated_bindparams(self): - """test that a BindParam can be used more than once. - this should be run for dbs with both positional and named paramstyles.""" + """Tests that a BindParam can be used more than once. + + This should be run for DB-APIs with both positional and named + paramstyles. + """ users.insert().execute(user_id = 7, user_name = 'jack') users.insert().execute(user_id = 8, user_name = 'fred') u = bindparam('userid') - s = users.select(or_(users.c.user_name==u, users.c.user_name==u)) + s = users.select(and_(users.c.user_name==u, users.c.user_name==u)) r = s.execute(userid='fred').fetchall() assert len(r) == 1 - + + u = bindparam('userid', unique=True) + s = users.select(and_(users.c.user_name==u, users.c.user_name==u)) + r = s.execute({u:'fred'}).fetchall() + assert len(r) == 1 + + def test_bindparams_in_params(self): + """test that a _BindParamClause itself can be a key in the params dict""" + + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'fred') + + u = bindparam('userid') + r = users.select(users.c.user_name==u).execute({u:'fred'}).fetchall() + assert len(r) == 1 + def test_bindparam_shortname(self): """test the 'shortname' field on BindParamClause.""" users.insert().execute(user_id = 7, user_name = 'jack') @@ -164,17 +242,48 @@ class QueryTest(PersistTest): s = users.select(users.c.user_name==u) r = s.execute(someshortname='fred').fetchall() assert len(r) == 1 - - def testdelete(self): + + def test_bindparam_detection(self): + dialect = default.DefaultDialect(paramstyle='qmark') + prep = lambda q: str(sql.text(q).compile(dialect=dialect)) + + def a_eq(got, wanted): + if got != wanted: + print "Wanted %s" % wanted + print "Received %s" % got + self.assert_(got == wanted, got) + + a_eq(prep('select foo'), 'select foo') + a_eq(prep("time='12:30:00'"), "time='12:30:00'") + a_eq(prep(u"time='12:30:00'"), u"time='12:30:00'") + a_eq(prep(":this:that"), ":this:that") + a_eq(prep(":this :that"), "? ?") + a_eq(prep("(:this),(:that :other)"), "(?),(? 
?)") + a_eq(prep("(:this),(:that:other)"), "(?),(:that:other)") + a_eq(prep("(:this),(:that,:other)"), "(?),(?,?)") + a_eq(prep("(:that_:other)"), "(:that_:other)") + a_eq(prep("(:that_ :other)"), "(? ?)") + a_eq(prep("(:that_other)"), "(?)") + a_eq(prep("(:that$other)"), "(?)") + a_eq(prep("(:that$:other)"), "(:that$:other)") + a_eq(prep(".:that$ :other."), ".? ?.") + + a_eq(prep(r'select \foo'), r'select \foo') + a_eq(prep(r"time='12\:30:00'"), r"time='12\:30:00'") + a_eq(prep(":this \:that"), "? :that") + a_eq(prep(r"(\:that$other)"), "(:that$other)") + a_eq(prep(r".\:that$ :other."), ".:that$ ?.") + + def test_delete(self): users.insert().execute(user_id = 7, user_name = 'jack') users.insert().execute(user_id = 8, user_name = 'fred') print repr(users.select().execute().fetchall()) users.delete(users.c.user_name == 'fred').execute() - + print repr(users.select().execute().fetchall()) - - def testselectlimit(self): + + def test_select_limit(self): users.insert().execute(user_id=1, user_name='john') users.insert().execute(user_id=2, user_name='jack') users.insert().execute(user_id=3, user_name='ed') @@ -184,9 +293,10 @@ class QueryTest(PersistTest): users.insert().execute(user_id=7, user_name='fido') r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r)) - + @testing.unsupported('mssql') - def testselectlimitoffset(self): + @testing.fails_on('maxdb') + def test_select_limit_offset(self): users.insert().execute(user_id=1, user_name='john') users.insert().execute(user_id=2, user_name='jack') users.insert().execute(user_id=3, user_name='ed') @@ -198,51 +308,118 @@ class QueryTest(PersistTest): self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')]) r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r==[(6, 'ralph'), (7, 'fido')]) - - @testing.supported('mssql') - def testselectlimitoffset_mssql(self): - try: - r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall() - assert False # InvalidRequestError should have been raised - except exceptions.InvalidRequestError: - pass - @testing.unsupported('mysql') + @testing.exclude('mysql', '<', (5, 0, 37)) def test_scalar_select(self): """test that scalar subqueries with labels get their type propigated to the result set.""" - # mysql and/or mysqldb has a bug here, type isnt propigated for scalar subquery. - datetable = Table('datetable', metadata, + # mysql and/or mysqldb has a bug here, type isn't propagated for scalar + # subquery. + datetable = Table('datetable', metadata, Column('id', Integer, primary_key=True), Column('today', DateTime)) datetable.create() try: datetable.insert().execute(id=1, today=datetime.datetime(2006, 5, 12, 12, 0, 0)) - s = select([datetable.alias('x').c.today], scalar=True) + s = select([datetable.alias('x').c.today]).as_scalar() s2 = select([datetable.c.id, s.label('somelabel')]) #print s2.c.somelabel.type assert isinstance(s2.execute().fetchone()['somelabel'], datetime.datetime) finally: datetable.drop() - + + def test_order_by(self): + """Exercises ORDER BY clause generation. + + Tests simple, compound, aliased and DESC clauses. 
+ """ + + users.insert().execute(user_id=1, user_name='c') + users.insert().execute(user_id=2, user_name='b') + users.insert().execute(user_id=3, user_name='a') + + def a_eq(executable, wanted): + got = list(executable.execute()) + self.assertEquals(got, wanted) + + for labels in False, True: + a_eq(users.select(order_by=[users.c.user_id], + use_labels=labels), + [(1, 'c'), (2, 'b'), (3, 'a')]) + + a_eq(users.select(order_by=[users.c.user_name, users.c.user_id], + use_labels=labels), + [(3, 'a'), (2, 'b'), (1, 'c')]) + + a_eq(select([users.c.user_id.label('foo')], + use_labels=labels, + order_by=[users.c.user_id]), + [(1,), (2,), (3,)]) + + a_eq(select([users.c.user_id.label('foo'), users.c.user_name], + use_labels=labels, + order_by=[users.c.user_name, users.c.user_id]), + [(3, 'a'), (2, 'b'), (1, 'c')]) + + a_eq(users.select(distinct=True, + use_labels=labels, + order_by=[users.c.user_id]), + [(1, 'c'), (2, 'b'), (3, 'a')]) + + a_eq(select([users.c.user_id.label('foo')], + distinct=True, + use_labels=labels, + order_by=[users.c.user_id]), + [(1,), (2,), (3,)]) + + a_eq(select([users.c.user_id.label('a'), + users.c.user_id.label('b'), + users.c.user_name], + use_labels=labels, + order_by=[users.c.user_id]), + [(1, 1, 'c'), (2, 2, 'b'), (3, 3, 'a')]) + + a_eq(users.select(distinct=True, + use_labels=labels, + order_by=[desc(users.c.user_id)]), + [(3, 'a'), (2, 'b'), (1, 'c')]) + + a_eq(select([users.c.user_id.label('foo')], + distinct=True, + use_labels=labels, + order_by=[users.c.user_id.desc()]), + [(3,), (2,), (1,)]) + def test_column_accessor(self): users.insert().execute(user_id=1, user_name='john') users.insert().execute(user_id=2, user_name='jack') addresses.insert().execute(address_id=1, user_id=2, address='foo@bar.com') - + r = users.select(users.c.user_id==2).execute().fetchone() self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') - r = text("select * from query_users where user_id=2", bind=testbase.db).execute().fetchone() + r = text("select * from query_users where user_id=2", bind=testing.db).execute().fetchone() self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') - + # test slices - r = text("select * from query_addresses", bind=testbase.db).execute().fetchone() + r = text("select * from query_addresses", bind=testing.db).execute().fetchone() self.assert_(r[0:1] == (1,)) self.assert_(r[1:] == (2, 'foo@bar.com')) self.assert_(r[:-1] == (1, 2)) - + + # test a little sqlite weirdness - with the UNION, cols come back as "query_users.user_id" in cursor.description + r = text("select query_users.user_id, query_users.user_name from query_users " + "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().fetchone() + self.assert_(r['user_id']) == 1 + self.assert_(r['user_name']) == "john" + + # test using literal tablename.colname + r = text('select query_users.user_id AS "query_users.user_id", query_users.user_name AS "query_users.user_name" from query_users', bind=testing.db).execute().fetchone() + self.assert_(r['query_users.user_id']) == 1 + self.assert_(r['query_users.user_name']) == "john" + + def test_ambiguous_column(self): users.insert().execute(user_id=1, user_name='john') r = users.outerjoin(addresses).select().execute().fetchone() @@ -250,8 +427,20 @@ class QueryTest(PersistTest): print r['user_id'] assert False except 
exceptions.InvalidRequestError, e: - assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." - + assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \ + str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement." + + def test_column_label_targeting(self): + users.insert().execute(user_id=7, user_name='ed') + + for s in ( + users.select().alias('foo'), + users.select().alias(users.name), + ): + row = s.select(use_labels=True).execute().fetchone() + assert row[s.c.user_id] == 7 + assert row[s.c.user_name] == 'ed' + def test_keys(self): users.insert().execute(user_id=1, user_name='foo') r = users.select().execute().fetchone() @@ -267,104 +456,21 @@ class QueryTest(PersistTest): r = users.select().execute().fetchone() self.assertEqual(len(r), 2) r.close() - r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone() + r = testing.db.execute('select user_name, user_id from query_users', {}).fetchone() self.assertEqual(len(r), 2) r.close() - r = testbase.db.execute('select user_name from query_users', {}).fetchone() + r = testing.db.execute('select user_name from query_users', {}).fetchone() self.assertEqual(len(r), 1) r.close() - + def test_cant_execute_join(self): try: users.join(addresses).execute() except exceptions.ArgumentError, e: - assert str(e) == """Not an executeable clause: query_users JOIN query_addresses ON query_users.user_id = query_addresses.user_id""" - - def test_functions(self): - x = testbase.db.func.current_date().execute().scalar() - y = testbase.db.func.current_date().select().execute().scalar() - z = testbase.db.func.current_date().scalar() - assert (x == y == z) is True - - x = testbase.db.func.current_date(type_=Date) - assert isinstance(x.type, Date) - assert isinstance(x.execute().scalar(), datetime.date) - - def test_conn_functions(self): - conn = testbase.db.connect() - try: - x = conn.execute(func.current_date()).scalar() - y = conn.execute(func.current_date().select()).scalar() - z = conn.scalar(func.current_date()) - finally: - conn.close() - assert (x == y == z) is True - - def test_update_functions(self): - """test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances, - and that column-level defaults get overridden""" - meta = MetaData(testbase.db) - t = Table('t1', meta, - Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True), - Column('value', Integer) - ) - t2 = Table('t2', meta, - Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True), - Column('value', Integer, default="7"), - Column('stuff', String(20), onupdate="thisisstuff") - ) - meta.create_all() - try: - t.insert().execute(value=func.length("one")) - assert t.select().execute().fetchone()['value'] == 3 - t.update().execute(value=func.length("asfda")) - assert t.select().execute().fetchone()['value'] == 5 - - r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() - id = r.last_inserted_ids()[0] - assert t.select(t.c.id==id).execute().fetchone()['value'] == 9 - t.update(values={t.c.value:func.length("asdf")}).execute() - assert t.select().execute().fetchone()['value'] == 4 - - t2.insert().execute() - t2.insert().execute(value=func.length("one")) - t2.insert().execute(value=func.length("asfda") + -19, stuff="hi") - - assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(7,None), (3,None), (-14,"hi")] - - 
t2.update().execute(value=func.length("asdsafasd"), stuff="some stuff") - assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] - - t2.delete().execute() - - t2.insert(values=dict(value=func.length("one") + 8)).execute() - assert t2.select().execute().fetchone()['value'] == 11 - - t2.update(values=dict(value=func.length("asfda"))).execute() - assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (5, "thisisstuff") - - t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute() - print "HI", select([t2.c.value, t2.c.stuff]).execute().fetchone() - assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (9, "foo") - - finally: - meta.drop_all() - - @testing.supported('postgres') - def test_functions_with_cols(self): - # TODO: shouldnt this work on oracle too ? - x = testbase.db.func.current_date().execute().scalar() - y = testbase.db.func.current_date().select().execute().scalar() - z = testbase.db.func.current_date().scalar() - w = select(['*'], from_obj=[testbase.db.func.current_date()]).scalar() - - # construct a column-based FROM object out of a function, like in [ticket:172] - s = select([column('date', type_=DateTime)], from_obj=[testbase.db.func.current_date()]) - q = s.execute().fetchone()[s.c.date] - r = s.alias('datequery').select().scalar() - - assert x == y == z == w == q == r - + assert str(e).startswith('Not an executable clause: ') + + + def test_column_order_with_simple_query(self): # should return values in column definition order users.insert().execute(user_id=1, user_name='foo') @@ -373,19 +479,19 @@ class QueryTest(PersistTest): self.assertEqual(r[1], 'foo') self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) self.assertEqual(r.values(), [1, 'foo']) - + def test_column_order_with_text_query(self): # should return values in query order users.insert().execute(user_id=1, user_name='foo') - r = testbase.db.execute('select user_name, user_id from query_users', {}).fetchone() + r = testing.db.execute('select user_name, user_id from query_users', {}).fetchone() self.assertEqual(r[0], 'foo') self.assertEqual(r[1], 1) self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id']) self.assertEqual(r.values(), ['foo', 1]) - - @testing.unsupported('oracle', 'firebird') + + @testing.unsupported('oracle', 'firebird', 'maxdb') def test_column_accessor_shadow(self): - meta = MetaData(testbase.db) + meta = MetaData(testing.db) shadowed = Table('test_shadowed', meta, Column('shadow_id', INT, primary_key = True), Column('shadow_name', VARCHAR(20)), @@ -412,137 +518,83 @@ class QueryTest(PersistTest): r.close() finally: shadowed.drop(checkfirst=True) - - @testing.supported('mssql') - def test_fetchid_trigger(self): - meta = MetaData(testbase.db) - t1 = Table('t1', meta, - Column('id', Integer, Sequence('fred', 100, 1), primary_key=True), - Column('descr', String(200))) - t2 = Table('t2', meta, - Column('id', Integer, Sequence('fred', 200, 1), primary_key=True), - Column('descr', String(200))) - meta.create_all() - con = testbase.db.connect() - con.execute("""create trigger paj on t1 for insert as - insert into t2 (descr) select descr from inserted""") - - try: - tr = con.begin() - r = con.execute(t2.insert(), descr='hello') - self.assert_(r.last_inserted_ids() == [200]) - r = con.execute(t1.insert(), descr='hello') - self.assert_(r.last_inserted_ids() == [100]) - - finally: - tr.commit() - con.execute("""drop trigger paj""") - meta.drop_all() - - 
@testing.supported('mssql') - def test_insertid_schema(self): - meta = MetaData(testbase.db) - con = testbase.db.connect() - con.execute('create schema paj') - tbl = Table('test', meta, Column('id', Integer, primary_key=True), schema='paj') - tbl.create() - try: - tbl.insert().execute({'id':1}) - finally: - tbl.drop() - con.execute('drop schema paj') - - @testing.supported('mssql') - def test_insertid_reserved(self): - meta = MetaData(testbase.db) - table = Table( - 'select', meta, - Column('col', Integer, primary_key=True) - ) - table.create() - - meta2 = MetaData(testbase.db) - try: - table.insert().execute(col=7) - finally: - table.drop() - + @testing.fails_on('maxdb') def test_in_filtering(self): """test the behavior of the in_() function.""" - + users.insert().execute(user_id = 7, user_name = 'jack') users.insert().execute(user_id = 8, user_name = 'fred') users.insert().execute(user_id = 9, user_name = None) - - s = users.select(users.c.user_name.in_()) + + s = users.select(users.c.user_name.in_([])) r = s.execute().fetchall() # No username is in empty set assert len(r) == 0 - - s = users.select(not_(users.c.user_name.in_())) + + s = users.select(not_(users.c.user_name.in_([]))) r = s.execute().fetchall() # All usernames with a value are outside an empty set assert len(r) == 2 - - s = users.select(users.c.user_name.in_('jack','fred')) + + s = users.select(users.c.user_name.in_(['jack','fred'])) r = s.execute().fetchall() assert len(r) == 2 - - s = users.select(not_(users.c.user_name.in_('jack','fred'))) + + s = users.select(not_(users.c.user_name.in_(['jack','fred']))) r = s.execute().fetchall() # Null values are not outside any set assert len(r) == 0 - + u = bindparam('search_key') - - s = users.select(u.in_()) + + s = users.select(u.in_([])) r = s.execute(search_key='john').fetchall() assert len(r) == 0 r = s.execute(search_key=None).fetchall() assert len(r) == 0 - - s = users.select(not_(u.in_())) + + s = users.select(not_(u.in_([]))) r = s.execute(search_key='john').fetchall() assert len(r) == 3 r = s.execute(search_key=None).fetchall() assert len(r) == 0 - - s = users.select(users.c.user_name.in_() == True) + + s = users.select(users.c.user_name.in_([]) == True) r = s.execute().fetchall() assert len(r) == 0 - s = users.select(users.c.user_name.in_() == False) + s = users.select(users.c.user_name.in_([]) == False) r = s.execute().fetchall() assert len(r) == 2 - s = users.select(users.c.user_name.in_() == None) + s = users.select(users.c.user_name.in_([]) == None) r = s.execute().fetchall() assert len(r) == 1 - -class CompoundTest(PersistTest): + +class CompoundTest(TestBase): """test compound statements like UNION, INTERSECT, particularly their ability to nest on different databases.""" def setUpAll(self): global metadata, t1, t2, t3 - metadata = MetaData(testbase.db) - t1 = Table('t1', metadata, + metadata = MetaData(testing.db) + t1 = Table('t1', metadata, Column('col1', Integer, Sequence('t1pkseq'), primary_key=True), Column('col2', String(30)), Column('col3', String(40)), Column('col4', String(30)) ) - t2 = Table('t2', metadata, + t2 = Table('t2', metadata, Column('col1', Integer, Sequence('t2pkseq'), primary_key=True), Column('col2', String(30)), Column('col3', String(40)), Column('col4', String(30))) - t3 = Table('t3', metadata, + t3 = Table('t3', metadata, Column('col1', Integer, Sequence('t3pkseq'), primary_key=True), Column('col2', String(30)), Column('col3', String(40)), Column('col4', String(30))) metadata.create_all() - + t1.insert().execute([ dict(col2="t1col2r1", 
col3="aaa", col4="aaa"), dict(col2="t1col2r2", col3="bbb", col4="bbb"), @@ -558,48 +610,121 @@ class CompoundTest(PersistTest): dict(col2="t3col2r2", col3="bbb", col4="aaa"), dict(col2="t3col2r3", col3="ccc", col4="bbb"), ]) - + def tearDownAll(self): metadata.drop_all() - + + def _fetchall_sorted(self, executed): + return sorted([tuple(row) for row in executed.fetchall()]) + def test_union(self): (s1, s2) = ( - select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], t1.c.col2.in_("t1col2r1", "t1col2r2")), - select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], t2.c.col2.in_("t2col2r2", "t2col2r3")) - ) + select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], + t1.c.col2.in_(["t1col2r1", "t1col2r2"])), + select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], + t2.c.col2.in_(["t2col2r2", "t2col2r3"])) + ) + u = union(s1, s2) + + wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), + ('ccc', 'aaa')] + found1 = self._fetchall_sorted(u.execute()) + self.assertEquals(found1, wanted) + + found2 = self._fetchall_sorted(u.alias('bar').select().execute()) + self.assertEquals(found2, wanted) + + def test_union_ordered(self): + (s1, s2) = ( + select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], + t1.c.col2.in_(["t1col2r1", "t1col2r2"])), + select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], + t2.c.col2.in_(["t2col2r2", "t2col2r3"])) + ) u = union(s1, s2, order_by=['col3', 'col4']) - assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - - @testing.unsupported('mysql') + + wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), + ('ccc', 'aaa')] + self.assertEquals(u.execute().fetchall(), wanted) + + @testing.fails_on('maxdb') + def test_union_ordered_alias(self): + (s1, s2) = ( + select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], + t1.c.col2.in_(["t1col2r1", "t1col2r2"])), + select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], + t2.c.col2.in_(["t2col2r2", "t2col2r3"])) + ) + u = union(s1, s2, order_by=['col3', 'col4']) + + wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), + ('ccc', 'aaa')] + self.assertEquals(u.alias('bar').select().execute().fetchall(), wanted) + + @testing.unsupported('sqlite', 'mysql', 'oracle') + def test_union_all(self): + e = union_all( + select([t1.c.col3]), + union( + select([t1.c.col3]), + select([t1.c.col3]), + ) + ) + + wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)] + found1 = self._fetchall_sorted(e.execute()) + self.assertEquals(found1, wanted) + + found2 = self._fetchall_sorted(e.alias('foo').select().execute()) + self.assertEquals(found2, wanted) + + @testing.unsupported('firebird', 'mysql', 'sybase') def test_intersect(self): i = intersect( select([t2.c.col3, t2.c.col4]), select([t2.c.col3, t2.c.col4], t2.c.col4==t3.c.col3) ) - assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - @testing.unsupported('mysql', 'oracle') + wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + + found1 = self._fetchall_sorted(i.execute()) + self.assertEquals(found1, wanted) + + found2 = self._fetchall_sorted(i.alias('bar').select().execute()) + self.assertEquals(found2, wanted) + + @testing.unsupported('firebird', 'mysql', 'oracle', 'sybase') def test_except_style1(self): e = except_(union( 
select([t1.c.col3, t1.c.col4]), select([t2.c.col3, t2.c.col4]), select([t3.c.col3, t3.c.col4]), ), select([t2.c.col3, t2.c.col4])) - assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] - @testing.unsupported('mysql', 'oracle') + wanted = [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), + ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] + + found = self._fetchall_sorted(e.alias('bar').select().execute()) + self.assertEquals(found, wanted) + + @testing.unsupported('firebird', 'mysql', 'oracle', 'sybase') def test_except_style2(self): e = except_(union( select([t1.c.col3, t1.c.col4]), select([t2.c.col3, t2.c.col4]), select([t3.c.col3, t3.c.col4]), ).alias('foo').select(), select([t2.c.col3, t2.c.col4])) - assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] - assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] - @testing.unsupported('sqlite', 'mysql', 'oracle') + wanted = [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), + ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] + + found1 = self._fetchall_sorted(e.execute()) + self.assertEquals(found1, wanted) + + found2 = self._fetchall_sorted(e.alias('bar').select().execute()) + self.assertEquals(found2, wanted) + + @testing.unsupported('firebird', 'mysql', 'oracle', 'sqlite', 'sybase') def test_except_style3(self): # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc e = except_( @@ -610,19 +735,10 @@ class CompoundTest(PersistTest): ) ) self.assertEquals(e.execute().fetchall(), [('ccc',)]) + self.assertEquals(e.alias('foo').select().execute().fetchall(), + [('ccc',)]) - @testing.unsupported('sqlite', 'mysql', 'oracle') - def test_union_union_all(self): - e = union_all( - select([t1.c.col3]), - union( - select([t1.c.col3]), - select([t1.c.col3]), - ) - ) - self.assertEquals(e.execute().fetchall(), [('aaa',),('bbb',),('ccc',),('aaa',),('bbb',),('ccc',)]) - - @testing.unsupported('mysql') + @testing.unsupported('firebird', 'mysql') def test_composite(self): u = intersect( select([t2.c.col3, t2.c.col4]), @@ -632,20 +748,308 @@ class CompoundTest(PersistTest): select([t3.c.col3, t3.c.col4]), ).alias('foo').select() ) - assert u.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - assert u.alias('foo').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + found = self._fetchall_sorted(u.execute()) + + self.assertEquals(found, wanted) + + @testing.unsupported('firebird', 'mysql') + def test_composite_alias(self): + ua = intersect( + select([t2.c.col3, t2.c.col4]), + union( + select([t1.c.col3, t1.c.col4]), + select([t2.c.col3, t2.c.col4]), + select([t3.c.col3, t3.c.col4]), + ).alias('foo').select() + ).alias('bar') + + wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + found = self._fetchall_sorted(ua.select().execute()) + self.assertEquals(found, wanted) + + +class JoinTest(TestBase): + """Tests join execution. + + The compiled SQL emitted by the dialect might be ANSI joins or + theta joins ('old oracle style', with (+) for OUTER). This test + tries to exercise join syntax and uncover any inconsistencies in + `JOIN rhs ON lhs.col=rhs.col` vs `rhs.col=lhs.col`. At least one + database seems to be sensitive to this. 
+ """ -class OperatorTest(PersistTest): + def setUpAll(self): + global metadata + global t1, t2, t3 + + metadata = MetaData(testing.db) + t1 = Table('t1', metadata, + Column('t1_id', Integer, primary_key=True), + Column('name', String(32))) + t2 = Table('t2', metadata, + Column('t2_id', Integer, primary_key=True), + Column('t1_id', Integer, ForeignKey('t1.t1_id')), + Column('name', String(32))) + t3 = Table('t3', metadata, + Column('t3_id', Integer, primary_key=True), + Column('t2_id', Integer, ForeignKey('t2.t2_id')), + Column('name', String(32))) + metadata.drop_all() + metadata.create_all() + + # t1.10 -> t2.20 -> t3.30 + # t1.11 -> t2.21 + # t1.12 + t1.insert().execute({'t1_id': 10, 'name': 't1 #10'}, + {'t1_id': 11, 'name': 't1 #11'}, + {'t1_id': 12, 'name': 't1 #12'}) + t2.insert().execute({'t2_id': 20, 't1_id': 10, 'name': 't2 #20'}, + {'t2_id': 21, 't1_id': 11, 'name': 't2 #21'}) + t3.insert().execute({'t3_id': 30, 't2_id': 20, 'name': 't3 #30'}) + + def tearDownAll(self): + metadata.drop_all() + + def assertRows(self, statement, expected): + """Execute a statement and assert that rows returned equal expected.""" + + found = sorted([tuple(row) + for row in statement.execute().fetchall()]) + + self.assertEquals(found, sorted(expected)) + + def test_join_x1(self): + """Joins t1->t2.""" + + for criteria in (t1.c.t1_id==t2.c.t1_id, t2.c.t1_id==t1.c.t1_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id], + from_obj=[t1.join(t2, criteria)]) + self.assertRows(expr, [(10, 20), (11, 21)]) + + def test_join_x2(self): + """Joins t1->t2->t3.""" + + for criteria in (t1.c.t1_id==t2.c.t1_id, t2.c.t1_id==t1.c.t1_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id], + from_obj=[t1.join(t2, criteria)]) + self.assertRows(expr, [(10, 20), (11, 21)]) + + def test_outerjoin_x1(self): + """Outer joins t1->t2.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id], + from_obj=[t1.join(t2).join(t3, criteria)]) + self.assertRows(expr, [(10, 20)]) + + def test_outerjoin_x2(self): + """Outer joins t1->t2,t3.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + from_obj=[t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). \ + outerjoin(t3, criteria)]) + self.assertRows(expr, [(10, 20, 30), (11, 21, None), (12, None, None)]) + + def test_outerjoin_where_x2_t1(self): + """Outer joins t1->t2,t3, where on t1.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t1.c.name == 't1 #10', + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t1.c.t1_id < 12, + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) + + def test_outerjoin_where_x2_t2(self): + """Outer joins t1->t2,t3, where on t2.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t2.c.name == 't2 #20', + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t2.c.t2_id < 29, + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). 
+ outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) + + def test_outerjoin_where_x2_t1t2(self): + """Outer joins t1->t2,t3, where on t1 and t2.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'), + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.t1_id < 19, 29 > t2.c.t2_id), + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) + + def test_outerjoin_where_x2_t3(self): + """Outer joins t1->t2,t3, where on t3.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t3.c.name == 't3 #30', + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t3.c.t3_id < 39, + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + def test_outerjoin_where_x2_t1t3(self): + """Outer joins t1->t2,t3, where on t1 and t3.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.name == 't1 #10', t3.c.name == 't3 #30'), + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.t1_id < 19, t3.c.t3_id < 39), + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + def test_outerjoin_where_x2_t1t2(self): + """Outer joins t1->t2,t3, where on t1 and t2.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'), + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.t1_id < 12, t2.c.t2_id < 39), + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) + + def test_outerjoin_where_x2_t1t2t3(self): + """Outer joins t1->t2,t3, where on t1, t2 and t3.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.name == 't1 #10', + t2.c.name == 't2 #20', + t3.c.name == 't3 #30'), + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). + outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.t1_id < 19, + t2.c.t2_id < 29, + t3.c.t3_id < 39), + from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). 
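# --- editor's sketch (not part of the patch): WHERE against the outer side ---
# The outerjoin tests check that a WHERE clause naming a column from the
# nullable (outer) side drops the NULL-extended rows, leaving only the fully
# joined row:
expr = select([t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
              t3.c.name == 't3 #30',
              from_obj=[t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id)
                          .outerjoin(t3, t2.c.t2_id == t3.c.t2_id)])
# expected result: [(10, 20, 30)]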
+ outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + def test_mixed(self): + """Joins t1->t2, outer t2->t3.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + print expr + self.assertRows(expr, [(10, 20, 30), (11, 21, None)]) + + def test_mixed_where(self): + """Joins t1->t2, outer t2->t3, plus a where on each table in turn.""" + + for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id): + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t1.c.name == 't1 #10', + from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t2.c.name == 't2 #20', + from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + t3.c.name == 't3 #30', + from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'), + from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t2.c.name == 't2 #20', t3.c.name == 't3 #30'), + from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + expr = select( + [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id], + and_(t1.c.name == 't1 #10', + t2.c.name == 't2 #20', + t3.c.name == 't3 #30'), + from_obj=[(t1.join(t2).outerjoin(t3, criteria))]) + self.assertRows(expr, [(10, 20, 30)]) + + +class OperatorTest(TestBase): def setUpAll(self): global metadata, flds - metadata = MetaData(testbase.db) - flds = Table('flds', metadata, + metadata = MetaData(testing.db) + flds = Table('flds', metadata, Column('idcol', Integer, Sequence('t1pkseq'), primary_key=True), Column('intcol', Integer), Column('strcol', String(50)), ) metadata.create_all() - + flds.insert().execute([ dict(intcol=5, strcol='foo'), dict(intcol=13, strcol='bar') @@ -653,12 +1057,16 @@ class OperatorTest(PersistTest): def tearDownAll(self): metadata.drop_all() - + + @testing.fails_on('maxdb') def test_modulo(self): self.assertEquals( - select([flds.c.intcol % 3], order_by=flds.c.idcol).execute().fetchall(), + select([flds.c.intcol % 3], + order_by=flds.c.idcol).execute().fetchall(), [(2,),(1,)] ) - + + + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/sql/quote.py b/test/sql/quote.py index 2fdf9dba0c..825e836ff8 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -1,15 +1,16 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import sql +from sqlalchemy.sql import compiler from testlib import * - -class QuoteTest(PersistTest): +class QuoteTest(TestBase): def setUpAll(self): # TODO: figure out which databases/which identifiers allow special characters to be used, # such as: spaces, quote characters, punctuation characters, set up tests for those as # well. 
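# --- editor's sketch (not part of the patch): the Python % operator on columns ---
# OperatorTest.test_modulo asserts that `column % n` renders as SQL modulo.
# A compile-only sketch using the free-roaming table()/column() constructs;
# the exact bind-parameter name shown is illustrative:
from sqlalchemy.sql import table, column
flds_demo = table('flds', column('intcol'))
print select([flds_demo.c.intcol % 3])
# roughly: SELECT flds.intcol % :intcol_1 FROM flds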
global table1, table2, table3 - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) table1 = Table('WorstCase1', metadata, Column('lowercase', Integer, primary_key=True), Column('UPPERCASE', Integer), @@ -21,15 +22,15 @@ class QuoteTest(PersistTest): Column('MixedCase', Integer)) table1.create() table2.create() - + def tearDown(self): table1.delete().execute() table2.delete().execute() - + def tearDownAll(self): table1.drop() table2.drop() - + def testbasic(self): table1.insert().execute({'lowercase':1,'UPPERCASE':2,'MixedCase':3,'a123':4}, {'lowercase':2,'UPPERCASE':2,'MixedCase':3,'a123':4}, @@ -37,19 +38,19 @@ class QuoteTest(PersistTest): table2.insert().execute({'d123':1,'u123':2,'MixedCase':3}, {'d123':2,'u123':2,'MixedCase':3}, {'d123':4,'u123':3,'MixedCase':2}) - + res1 = select([table1.c.lowercase, table1.c.UPPERCASE, table1.c.MixedCase, table1.c.a123]).execute().fetchall() print res1 assert(res1==[(1,2,3,4),(2,2,3,4),(4,3,2,1)]) - + res2 = select([table2.c.d123, table2.c.u123, table2.c.MixedCase]).execute().fetchall() print res2 assert(res2==[(1,2,3),(2,2,3),(4,3,2)]) - + def testreflect(self): - meta2 = MetaData(testbase.db) + meta2 = MetaData(testing.db) t2 = Table('WorstCase2', meta2, autoload=True, quote=True) - assert t2.c.has_key('MixedCase') + assert 'MixedCase' in t2.c def testlabels(self): table1.insert().execute({'lowercase':1,'UPPERCASE':2,'MixedCase':3,'a123':4}, @@ -58,43 +59,31 @@ class QuoteTest(PersistTest): table2.insert().execute({'d123':1,'u123':2,'MixedCase':3}, {'d123':2,'u123':2,'MixedCase':3}, {'d123':4,'u123':3,'MixedCase':2}) - + res1 = select([table1.c.lowercase, table1.c.UPPERCASE, table1.c.MixedCase, table1.c.a123], use_labels=True).execute().fetchall() print res1 assert(res1==[(1,2,3,4),(2,2,3,4),(4,3,2,1)]) - + res2 = select([table2.c.d123, table2.c.u123, table2.c.MixedCase], use_labels=True).execute().fetchall() print res2 assert(res2==[(1,2,3),(2,2,3),(4,3,2)]) - - def testcascade(self): - lcmetadata = MetaData(case_sensitive=False) - t1 = Table('SomeTable', lcmetadata, - Column('UcCol', Integer), - Column('normalcol', String)) - t2 = Table('othertable', lcmetadata, - Column('UcCol', Integer), - Column('normalcol', String, ForeignKey('SomeTable.normalcol'))) - assert lcmetadata.case_sensitive is False - assert t1.c.UcCol.case_sensitive is False - assert t2.c.normalcol.case_sensitive is False - - @testing.unsupported('oracle') + + @testing.unsupported('oracle') def testlabels(self): """test the quoting of labels. - + if labels arent quoted, a query in postgres in particular will fail since it produces: - - SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC" + + SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC" FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, "WorstCase1"."UPPERCASE" AS UPPERCASE, "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa - + where the "UPPERCASE" column of "LaLa" doesnt exist. 
""" x = table1.select(distinct=True).alias("LaLa").select().scalar() def testlabels2(self): metadata = MetaData() - table = Table("ImATable", metadata, + table = Table("ImATable", metadata, Column("col1", Integer)) x = select([table.c.col1.label("ImATable_col1")]).alias("SomeAlias") assert str(select([x.c.ImATable_col1])) == '''SELECT "SomeAlias"."ImATable_col1" \nFROM (SELECT "ImATable".col1 AS "ImATable_col1" \nFROM "ImATable") AS "SomeAlias"''' @@ -103,40 +92,61 @@ class QuoteTest(PersistTest): x = select([sql.literal_column("'foo'").label("somelabel")], from_obj=[table]).alias("AnAlias") x = x.select() assert str(x) == '''SELECT "AnAlias".somelabel \nFROM (SELECT 'foo' AS somelabel \nFROM "ImATable") AS "AnAlias"''' - + x = select([sql.literal_column("'FooCol'").label("SomeLabel")], from_obj=[table]) x = x.select() assert str(x) == '''SELECT "SomeLabel" \nFROM (SELECT 'FooCol' AS "SomeLabel" \nFROM "ImATable")''' - - def testlabelsnocase(self): - metadata = MetaData() - table1 = Table('SomeCase1', metadata, - Column('lowercase', Integer, primary_key=True), - Column('UPPERCASE', Integer), - Column('MixedCase', Integer)) - table2 = Table('SomeCase2', metadata, - Column('id', Integer, primary_key=True, key='d123'), - Column('col2', Integer, key='u123'), - Column('MixedCase', Integer)) - - # first test case sensitive tables migrating via tometadata - meta = MetaData(testbase.db, case_sensitive=False) - lc_table1 = table1.tometadata(meta) - lc_table2 = table2.tometadata(meta) - assert lc_table1.case_sensitive is False - assert lc_table1.c.UPPERCASE.case_sensitive is False - s = lc_table1.select() - assert hasattr(s.c.UPPERCASE, "case_sensitive") - assert s.c.UPPERCASE.case_sensitive is False - - # now, the aliases etc. should be case-insensitive. PG will screw up if this doesnt work. - # also, if this test is run in the context of the other tests, we also test that the dialect properly - # caches identifiers with "case_sensitive" and "not case_sensitive" separately. 
- meta.create_all() - try: - x = lc_table1.select(distinct=True).alias("lala").select().scalar() - finally: - meta.drop_all() - + + +class PreparerTest(TestBase): + """Test the db-agnostic quoting services of IdentifierPreparer.""" + + def test_unformat(self): + prep = compiler.IdentifierPreparer(None) + unformat = prep.unformat_identifiers + + def a_eq(have, want): + if have != want: + print "Wanted %s" % want + print "Received %s" % have + self.assert_(have == want) + + a_eq(unformat('foo'), ['foo']) + a_eq(unformat('"foo"'), ['foo']) + a_eq(unformat("'foo'"), ["'foo'"]) + a_eq(unformat('foo.bar'), ['foo', 'bar']) + a_eq(unformat('"foo"."bar"'), ['foo', 'bar']) + a_eq(unformat('foo."bar"'), ['foo', 'bar']) + a_eq(unformat('"foo".bar'), ['foo', 'bar']) + a_eq(unformat('"foo"."b""a""r"."baz"'), ['foo', 'b"a"r', 'baz']) + + def test_unformat_custom(self): + class Custom(compiler.IdentifierPreparer): + def __init__(self, dialect): + super(Custom, self).__init__(dialect, initial_quote='`', + final_quote='`') + def _escape_identifier(self, value): + return value.replace('`', '``') + def _unescape_identifier(self, value): + return value.replace('``', '`') + + prep = Custom(None) + unformat = prep.unformat_identifiers + + def a_eq(have, want): + if have != want: + print "Wanted %s" % want + print "Received %s" % have + self.assert_(have == want) + + a_eq(unformat('foo'), ['foo']) + a_eq(unformat('`foo`'), ['foo']) + a_eq(unformat(`'foo'`), ["'foo'"]) + a_eq(unformat('foo.bar'), ['foo', 'bar']) + a_eq(unformat('`foo`.`bar`'), ['foo', 'bar']) + a_eq(unformat('foo.`bar`'), ['foo', 'bar']) + a_eq(unformat('`foo`.bar'), ['foo', 'bar']) + a_eq(unformat('`foo`.`b``a``r`.`baz`'), ['foo', 'b`a`r', 'baz']) + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py index e0da96a81d..3c9caad754 100644 --- a/test/sql/rowcount.py +++ b/test/sql/rowcount.py @@ -1,17 +1,17 @@ -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * from testlib import * -class FoundRowsTest(AssertMixin): +class FoundRowsTest(TestBase, AssertsExecutionResults): """tests rowcount functionality""" def setUpAll(self): - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) global employees_table employees_table = Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, Sequence('employee_id_seq', optional=True), primary_key=True), Column('name', String(50)), Column('department', String(1)), ) @@ -47,26 +47,25 @@ class FoundRowsTest(AssertMixin): # WHERE matches 3, 3 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='Z') - if testbase.db.dialect.supports_sane_rowcount(): + print "expecting 3, dialect reports %s" % r.rowcount + if testing.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 def test_update_rowcount2(self): # WHERE matches 3, 0 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='C') - if testbase.db.dialect.supports_sane_rowcount(): + print "expecting 3, dialect reports %s" % r.rowcount + if testing.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 def test_delete_rowcount(self): # WHERE matches 3, 3 rows deleted department = employees_table.c.department r = employees_table.delete(department=='C').execute() - if testbase.db.dialect.supports_sane_rowcount(): + print "expecting 3, dialect reports %s" % r.rowcount 
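# --- editor's sketch (not part of the patch): unformat_identifiers() ---
# PreparerTest exercises the db-agnostic splitting of quoted dotted names
# back into components, including doubled-quote escapes (compiler is imported
# at the top of quote.py per the hunk above):
prep = compiler.IdentifierPreparer(None)
assert prep.unformat_identifiers('"foo"."b""a""r"."baz"') == ['foo', 'b"a"r', 'baz']
assert prep.unformat_identifiers('foo."bar"') == ['foo', 'bar']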
+ if testing.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 if __name__ == '__main__': - testbase.main() - - - - + testenv.main() diff --git a/test/sql/select.py b/test/sql/select.py index a5cf061e21..bea8621121 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -1,28 +1,25 @@ -import testbase -import re, operator +import testenv; testenv.configure_for_tests() +import datetime, re, operator from sqlalchemy import * +from sqlalchemy import exceptions, sql, util +from sqlalchemy.sql import table, column, compiler from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql from testlib import * - -# the select test now tests almost completely with TableClause/ColumnClause objects, -# which are free-roaming table/column objects not attached to any database. -# so SQLAlchemy's SQL construction engine can be used with no database dependencies at all. - -table1 = table('mytable', +table1 = table('mytable', column('myid', Integer), column('name', String), column('description', String), ) table2 = table( - 'myothertable', + 'myothertable', column('otherid', Integer), column('othername', String), ) table3 = table( - 'thirdtable', + 'thirdtable', column('userid', Integer), column('otherstuff', String), ) @@ -36,13 +33,13 @@ table4 = Table( schema = 'remote_owner' ) -users = table('users', +users = table('users', column('user_id'), column('user_name'), column('password'), ) -addresses = table('addresses', +addresses = table('addresses', column('address_id'), column('user_id'), column('street'), @@ -51,65 +48,69 @@ addresses = table('addresses', column('zip') ) -class SQLTest(PersistTest): - def runtest(self, clause, result, dialect = None, params = None, checkparams = None): - c = clause.compile(parameters=params, dialect=dialect) - print "\nSQL String:\n" + str(c) + repr(c.get_params()) - cc = re.sub(r'\n', '', str(c)) - self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'") - if checkparams is not None: - if isinstance(checkparams, list): - self.assert_(c.get_params().get_raw_list() == checkparams, "params dont match ") - else: - self.assert_(c.get_params().get_original_dict() == checkparams, "params dont match" + repr(c.get_params())) - -class SelectTest(SQLTest): - def testtableselect(self): - self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable") - - self.runtest(select([table1, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \ +class SelectTest(TestBase, AssertsCompiledSQL): + + def test_attribute_sanity(self): + assert hasattr(table1, 'c') + assert hasattr(table1.select(), 'c') + assert not hasattr(table1.c.myid.self_group(), 'columns') + assert hasattr(table1.select().self_group(), 'columns') + assert not hasattr(select([table1.c.myid]).as_scalar().self_group(), 'columns') + assert not hasattr(table1.c.myid, 'columns') + assert not hasattr(table1.c.myid, 'c') + assert not hasattr(table1.select().c.myid, 'c') + assert not hasattr(table1.select().c.myid, 'columns') + assert not hasattr(table1.alias().c.myid, 'columns') + assert not hasattr(table1.alias().c.myid, 'c') + + def test_table_select(self): + self.assert_compile(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable") + + self.assert_compile(select([table1, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \ myothertable.othername FROM mytable, myothertable") - def testselectselect(self): + def 
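# --- editor's sketch (not part of the patch): guarded rowcount assertions ---
# The rowcount hunks switch supports_sane_rowcount from a method call to a
# plain attribute, and only assert when the dialect reports sane counts:
r = employees_table.update(employees_table.c.department == 'C').execute(department='Z')
print "expecting 3, dialect reports %s" % r.rowcount
if testing.db.dialect.supports_sane_rowcount:
    assert r.rowcount == 3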
test_from_subquery(self): """tests placing select statements in the column clause of another select, for the purposes of selecting from the exported columns of that select.""" + s = select([table1], table1.c.name == 'jack') - self.runtest( + self.assert_compile( select( [s], s.c.myid == 7 ) , - "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable WHERE mytable.name = :mytable_name) WHERE myid = :myid") - + "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable "\ + "WHERE mytable.name = :name_1) WHERE myid = :myid_1") + sq = select([table1]) - self.runtest( + self.assert_compile( sq.select(), "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)" ) - + sq = select( [table1], ).alias('sq') - self.runtest( - sq.select(sq.c.myid == 7), + self.assert_compile( + sq.select(sq.c.myid == 7), "SELECT sq.myid, sq.name, sq.description FROM \ -(SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.myid = :sq_myid" +(SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.myid = :myid_1" ) - + sq = select( [table1, table2], and_(table1.c.myid ==7, table2.c.otherid==table1.c.myid), use_labels = True ).alias('sq') - + sqstring = "SELECT mytable.myid AS mytable_myid, mytable.name AS mytable_name, \ mytable.description AS mytable_description, myothertable.otherid AS myothertable_otherid, \ myothertable.othername AS myothertable_othername FROM mytable, myothertable \ -WHERE mytable.myid = :mytable_myid AND myothertable.otherid = mytable.myid" +WHERE mytable.myid = :myid_1 AND myothertable.otherid = mytable.myid" - self.runtest(sq.select(), "SELECT sq.mytable_myid, sq.mytable_name, sq.mytable_description, sq.myothertable_otherid, \ + self.assert_compile(sq.select(), "SELECT sq.mytable_myid, sq.mytable_name, sq.mytable_description, sq.myothertable_otherid, \ sq.myothertable_othername FROM (" + sqstring + ") AS sq") sq2 = select( @@ -117,111 +118,173 @@ sq.myothertable_othername FROM (" + sqstring + ") AS sq") use_labels = True ).alias('sq2') - self.runtest(sq2.select(), "SELECT sq2.sq_mytable_myid, sq2.sq_mytable_name, sq2.sq_mytable_description, \ + self.assert_compile(sq2.select(), "SELECT sq2.sq_mytable_myid, sq2.sq_mytable_name, sq2.sq_mytable_description, \ sq2.sq_myothertable_otherid, sq2.sq_myothertable_othername FROM \ (SELECT sq.mytable_myid AS sq_mytable_myid, sq.mytable_name AS sq_mytable_name, \ sq.mytable_description AS sq_mytable_description, sq.myothertable_otherid AS sq_myothertable_otherid, \ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") AS sq) AS sq2") - - def testmssql_noorderbyinsubquery(self): - """test that the ms-sql dialect removes ORDER BY clauses from subqueries""" - dialect = mssql.dialect() - q = select([table1.c.myid], order_by=[table1.c.myid]).alias('foo') - crit = q.c.myid == table1.c.myid - self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable ORDER BY mytable.myid) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=sqlite.dialect()) - self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=mssql.dialect()) - - def 
testmssql_aliases_schemas(self): - self.runtest(table4.select(), "SELECT remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM remote_owner.remotetable") - - dialect = mssql.dialect() - self.runtest(table4.select(), "SELECT remotetable_1.rem_id, remotetable_1.datatype_id, remotetable_1.value FROM remote_owner.remotetable AS remotetable_1", dialect=dialect) - - # TODO: this is probably incorrect; no "AS " is being applied to the table - self.runtest(table1.join(table4, table1.c.myid==table4.c.rem_id).select(), "SELECT mytable.myid, mytable.name, mytable.description, remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM mytable JOIN remote_owner.remotetable ON remotetable.rem_id = mytable.myid") + + def test_nested_uselabels(self): + """test nested anonymous label generation. this + essentially tests the ANONYMOUS_LABEL regex. + + """ + s1 = table1.select() + s2 = s1.alias() + s3 = select([s2], use_labels=True) + s4 = s3.alias() + s5 = select([s4], use_labels=True) + self.assert_compile(s5, "SELECT anon_1.anon_2_myid AS anon_1_anon_2_myid, anon_1.anon_2_name AS anon_1_anon_2_name, "\ + "anon_1.anon_2_description AS anon_1_anon_2_description FROM (SELECT anon_2.myid AS anon_2_myid, anon_2.name AS anon_2_name, "\ + "anon_2.description AS anon_2_description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description "\ + "AS description FROM mytable) AS anon_2) AS anon_1") - def testdontovercorrelate(self): - self.runtest(select([table1], from_obj=[table1, table1.select()]), """SELECT mytable.myid, mytable.name, mytable.description FROM mytable, (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)""") + def test_dont_overcorrelate(self): + self.assert_compile(select([table1], from_obj=[table1, table1.select()]), """SELECT mytable.myid, mytable.name, mytable.description FROM mytable, (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)""") - def testexistsascolumnclause(self): - self.runtest(exists([table1.c.myid], table1.c.myid==5).select(), "SELECT EXISTS (SELECT mytable.myid AS myid FROM mytable WHERE mytable.myid = :mytable_myid)", params={'mytable_myid':5}) + def test_full_correlate(self): + # intentional + t = table('t', column('a'), column('b')) + s = select([t.c.a]).where(t.c.a==1).correlate(t).as_scalar() - self.runtest(select([table1, exists([1], from_obj=[table2])]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) FROM mytable", params={}) - - self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={}) + s2 = select([t.c.a, s]) + self.assert_compile(s2, """SELECT t.a, (SELECT t.a WHERE t.a = :a_1) AS anon_1 FROM t""") + + # unintentional + t2 = table('t2', column('c'), column('d')) + s = select([t.c.a]).where(t.c.a==t2.c.d).as_scalar() + s2 =select([t, t2, s]) + self.assertRaises(exceptions.InvalidRequestError, str, s2) + + # intentional again + s = s.correlate(t, t2) + s2 =select([t, t2, s]) + self.assert_compile(s, "SELECT t.a WHERE t.a = t2.d") - def testwheresubquery(self): + def test_exists(self): + self.assert_compile(exists([table1.c.myid], table1.c.myid==5).select(), "SELECT EXISTS (SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)", params={'mytable_myid':5}) + + self.assert_compile(select([table1, exists([1], from_obj=table2)]), "SELECT 
mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) FROM mytable", params={}) + + self.assert_compile(select([table1, exists([1], from_obj=table2).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={}) + + self.assert_compile( + table1.select(exists([1], table2.c.otherid == table1.c.myid).correlate(table1)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)" + ) + + self.assert_compile( + table1.select(exists([1]).where(table2.c.otherid == table1.c.myid).correlate(table1)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)" + ) + + self.assert_compile( + table1.select(exists([1]).where(table2.c.otherid == table1.c.myid).correlate(table1)).replace_selectable(table2, table2.alias()), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable AS myothertable_1 WHERE myothertable_1.otherid = mytable.myid)" + ) + + self.assert_compile( + table1.select(exists([1]).where(table2.c.otherid == table1.c.myid).correlate(table1)).select_from(table1.join(table2, table1.c.myid==table2.c.otherid)).replace_selectable(table2, table2.alias()), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable JOIN myothertable AS myothertable_1 ON mytable.myid = myothertable_1.otherid WHERE EXISTS (SELECT 1 FROM myothertable AS myothertable_1 WHERE myothertable_1.otherid = mytable.myid)" + ) + + def test_where_subquery(self): s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s') - self.runtest( - select([users, s.c.street], from_obj=[s]), + self.assert_compile( + select([users, s.c.street], from_obj=s), """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""") - # TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet. 
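# --- editor's sketch (not part of the patch): pinning EXISTS correlation ---
# test_exists spells the correlation out with .correlate(), so the inner
# SELECT omits mytable from its own FROM list:
print table1.select(
    exists([1]).where(table2.c.otherid == table1.c.myid).correlate(table1))
# SELECT mytable.myid, mytable.name, mytable.description FROM mytable
#   WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)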
- #self.runtest( - # table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), "" - #) - - self.runtest( + self.assert_compile( + table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT mytable.myid FROM mytable WHERE mytable.name = :name_1)" + ) + + self.assert_compile( table1.select(table1.c.myid == select([table2.c.otherid], table1.c.name == table2.c.othername)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid FROM myothertable WHERE mytable.name = myothertable.othername)" ) - self.runtest( + self.assert_compile( table1.select(exists([1], table2.c.otherid == table1.c.myid)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)" ) + talias = table1.alias('ta') s = subquery('sq2', [talias], exists([1], table2.c.otherid == talias.c.myid)) - self.runtest( + self.assert_compile( select([s, table1]) ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS myid, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable") s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s') - self.runtest( - select([users, s.c.street], from_obj=[s]), + self.assert_compile( + select([users, s.c.street], from_obj=s), """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""") - + # test constructing the outer query via append_column(), which occurs in the ORM's Query object - s = select([], exists([1], table2.c.otherid==table1.c.myid), from_obj=[table1]) + s = select([], exists([1], table2.c.otherid==table1.c.myid), from_obj=table1) s.append_column(table1) - self.runtest( + self.assert_compile( s, "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)" ) - - def testorderbysubquery(self): - self.runtest( + + def test_orderby_subquery(self): + self.assert_compile( table1.select(order_by=[select([table2.c.otherid], table1.c.myid==table2.c.otherid)]), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable ORDER BY (SELECT myothertable.otherid FROM myothertable WHERE mytable.myid = myothertable.otherid)" ) - self.runtest( + self.assert_compile( table1.select(order_by=[desc(select([table2.c.otherid], table1.c.myid==table2.c.otherid))]), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable ORDER BY (SELECT myothertable.otherid FROM myothertable WHERE mytable.myid = myothertable.otherid) DESC" ) - - - def testcolumnsubquery(self): + + @testing.uses_deprecated('scalar option') + def test_scalar_select(self): + try: + s = select([table1.c.myid, table1.c.name]).as_scalar() + assert False + except exceptions.InvalidRequestError, err: + assert str(err) == "Scalar select can only be created from a Select object that has exactly one column expression.", str(err) + + try: + # generic function which will look at the type of expression + func.coalesce(select([table1.c.myid])) + assert False + except exceptions.InvalidRequestError, err: + assert str(err) == "Select 
objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.", str(err) + s = select([table1.c.myid], scalar=True, correlate=False) - self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable") + self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) AS anon_1 FROM mytable") s = select([table1.c.myid], scalar=True) - self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable") + self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable") s = select([table1.c.myid]).correlate(None).as_scalar() - self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable") + self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) AS anon_1 FROM mytable") s = select([table1.c.myid]).as_scalar() - self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable") + self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable") # test expressions against scalar selects - self.runtest(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal") - self.runtest(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal") - self.runtest(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal") + self.assert_compile(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :param_1 AS anon_1") + self.assert_compile(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :param_1 AS anon_1") + self.assert_compile(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :param_1 AS anon_1") - self.runtest(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo") + self.assert_compile(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo") + + # scalar selects should not have any attributes on their 'c' or 'columns' attribute + s = select([table1.c.myid]).as_scalar() + try: + s.c.foo + except exceptions.InvalidRequestError, err: + assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.' + try: + s.columns.foo + except exceptions.InvalidRequestError, err: + assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.' 
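# --- editor's sketch (not part of the patch): scalar selects in expressions ---
# as_scalar() turns a one-column SELECT into a value expression; embedded in
# another select it now receives an anonymous label (anon_1):
s = select([table1.c.myid]).as_scalar()
print select([table2, s])
# ... (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable
print select([s - literal(8)])
# SELECT (SELECT mytable.myid FROM mytable) - :param_1 AS anon_1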
zips = table('zips', column('zipcode'), @@ -235,78 +298,90 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A zip = '12345' qlat = select([zips.c.latitude], zips.c.zipcode == zip).correlate(None).as_scalar() qlng = select([zips.c.longitude], zips.c.zipcode == zip).correlate(None).as_scalar() - + q = select([places.c.id, places.c.nm, zips.c.zipcode, func.latlondist(qlat, qlng).label('dist')], zips.c.zipcode==zip, order_by = ['dist', places.c.nm] ) - self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE " - "zips.zipcode = :zips_zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_1)) AS dist " - "FROM places, zips WHERE zips.zipcode = :zips_zipcode_2 ORDER BY dist, places.nm") - + self.assert_compile(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE " + "zips.zipcode = :zipcode_1), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zipcode_2)) AS dist " + "FROM places, zips WHERE zips.zipcode = :zipcode_3 ORDER BY dist, places.nm") + zalias = zips.alias('main_zip') qlat = select([zips.c.latitude], zips.c.zipcode == zalias.c.zipcode, scalar=True) qlng = select([zips.c.longitude], zips.c.zipcode == zalias.c.zipcode, scalar=True) q = select([places.c.id, places.c.nm, zalias.c.zipcode, func.latlondist(qlat, qlng).label('dist')], order_by = ['dist', places.c.nm] ) - self.runtest(q, "SELECT places.id, places.nm, main_zip.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = main_zip.zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = main_zip.zipcode)) AS dist FROM places, zips AS main_zip ORDER BY dist, places.nm") + self.assert_compile(q, "SELECT places.id, places.nm, main_zip.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = main_zip.zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = main_zip.zipcode)) AS dist FROM places, zips AS main_zip ORDER BY dist, places.nm") a1 = table2.alias('t2alias') s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True) j1 = table1.join(table2, table1.c.myid==table2.c.otherid) - s2 = select([table1, s1], from_obj=[j1]) - self.runtest(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid") - - def testlabelcomparison(self): + s2 = select([table1, s1], from_obj=j1) + self.assert_compile(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) AS anon_1 FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid") + + def test_label_comparison(self): x = func.lala(table1.c.myid).label('foo') - self.runtest(select([x], x==5), "SELECT lala(mytable.myid) AS foo FROM mytable WHERE lala(mytable.myid) = :literal") - - def testand(self): - self.runtest( - select(['*'], and_(table1.c.myid == 12, table1.c.name=='asdf', table2.c.othername == 'foo', "sysdate() = today()")), - "SELECT * FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name AND myothertable.othername = :myothertable_othername AND sysdate() = today()" + self.assert_compile(select([x], x==5), "SELECT lala(mytable.myid) AS foo FROM mytable WHERE lala(mytable.myid) = :param_1") + + def test_conjunctions(self): + self.assert_compile( + and_(table1.c.myid == 12, 
table1.c.name=='asdf', table2.c.othername == 'foo', "sysdate() = today()"), + "mytable.myid = :myid_1 AND mytable.name = :name_1 "\ + "AND myothertable.othername = :othername_1 AND sysdate() = today()" ) - def testor(self): - self.runtest( - select([table1], and_( + self.assert_compile( + and_( table1.c.myid == 12, or_(table2.c.othername=='asdf', table2.c.othername == 'foo', table2.c.otherid == 9), - "sysdate() = today()", - )), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND (myothertable.othername = :myothertable_othername OR myothertable.othername = :myothertable_othername_1 OR myothertable.otherid = :myothertable_otherid) AND sysdate() = today()", - checkparams = {'myothertable_othername': 'asdf', 'myothertable_othername_1':'foo', 'myothertable_otherid': 9, 'mytable_myid': 12} + "sysdate() = today()", + ), + "mytable.myid = :myid_1 AND (myothertable.othername = :othername_1 OR "\ + "myothertable.othername = :othername_2 OR myothertable.otherid = :otherid_1) AND sysdate() = today()", + checkparams = {'othername_1': 'asdf', 'othername_2':'foo', 'otherid_1': 9, 'myid_1': 12} ) - def testdistinct(self): - self.runtest( + def test_distinct(self): + self.assert_compile( select([table1.c.myid.distinct()]), "SELECT DISTINCT mytable.myid FROM mytable" ) - self.runtest( + self.assert_compile( select([distinct(table1.c.myid)]), "SELECT DISTINCT mytable.myid FROM mytable" ) - - def testoperators(self): - # exercise arithmetic operators + self.assert_compile( + select([table1.c.myid]).distinct(), "SELECT DISTINCT mytable.myid FROM mytable" + ) + + self.assert_compile( + select([func.count(table1.c.myid.distinct())]), "SELECT count(DISTINCT mytable.myid) AS count_1 FROM mytable" + ) + + self.assert_compile( + select([func.count(distinct(table1.c.myid))]), "SELECT count(DISTINCT mytable.myid) AS count_1 FROM mytable" + ) + + def test_operators(self): for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), (operator.sub, '-'), (operator.div, '/'), ): for (lhs, rhs, res) in ( - (5, table1.c.myid, ':mytable_myid %s mytable.myid'), - (5, literal(5), ':literal %s :literal_1'), - (table1.c.myid, 'b', 'mytable.myid %s :mytable_myid'), - (table1.c.myid, literal(2.7), 'mytable.myid %s :literal'), + (5, table1.c.myid, ':myid_1 %s mytable.myid'), + (5, literal(5), ':param_1 %s :param_2'), + (table1.c.myid, 'b', 'mytable.myid %s :myid_1'), + (table1.c.myid, literal(2.7), 'mytable.myid %s :param_1'), (table1.c.myid, table1.c.myid, 'mytable.myid %s mytable.myid'), - (literal(5), 8, ':literal %s :literal_1'), - (literal(6), table1.c.myid, ':literal %s mytable.myid'), - (literal(7), literal(5.5), ':literal %s :literal_1'), + (literal(5), 8, ':param_1 %s :param_2'), + (literal(6), table1.c.myid, ':param_1 %s mytable.myid'), + (literal(7), literal(5.5), ':param_1 %s :param_2'), ): - self.runtest(py_op(lhs, rhs), res % sql_op) + self.assert_compile(py_op(lhs, rhs), res % sql_op) + dt = datetime.datetime.today() # exercise comparison operators for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'), (operator.gt, '>', '<'), @@ -315,14 +390,16 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A (operator.le, '<=', '>='), (operator.ge, '>=', '<=')): for (lhs, rhs, l_sql, r_sql) in ( - ('a', table1.c.myid, ':mytable_myid', 'mytable.myid'), - ('a', literal('b'), ':literal_1', ':literal'), # note swap! 
- (table1.c.myid, 'b', 'mytable.myid', ':mytable_myid'), - (table1.c.myid, literal('b'), 'mytable.myid', ':literal'), + ('a', table1.c.myid, ':myid_1', 'mytable.myid'), + ('a', literal('b'), ':param_2', ':param_1'), # note swap! + (table1.c.myid, 'b', 'mytable.myid', ':myid_1'), + (table1.c.myid, literal('b'), 'mytable.myid', ':param_1'), (table1.c.myid, table1.c.myid, 'mytable.myid', 'mytable.myid'), - (literal('a'), 'b', ':literal', ':literal_1'), - (literal('a'), table1.c.myid, ':literal', 'mytable.myid'), - (literal('a'), literal('b'), ':literal', ':literal_1'), + (literal('a'), 'b', ':param_1', ':param_2'), + (literal('a'), table1.c.myid, ':param_1', 'mytable.myid'), + (literal('a'), literal('b'), ':param_1', ':param_2'), + (dt, literal('b'), ':param_2', ':param_1'), + (literal('b'), dt, ':param_1', ':param_2'), ): # the compiled clause should match either (e.g.): @@ -335,95 +412,164 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A "\n'" + compiled + "'\n does not match\n'" + fwd_sql + "'\n or\n'" + rev_sql + "'") - self.runtest( - table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND mytable.name != :mytable_name" + self.assert_compile( + table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :myid_1 AND mytable.name != :name_1" + ) + + self.assert_compile( + table1.select((table1.c.myid != 12) & ~(table1.c.name.between('jack','john'))), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :myid_1 AND "\ + "NOT (mytable.name BETWEEN :name_1 AND :name_2)" ) - self.runtest( - table1.select((table1.c.myid != 12) & ~and_(table1.c.name=='john', table1.c.name=='ed', table1.c.name=='fred')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name AND mytable.name = :mytable_name_1 AND mytable.name = :mytable_name_2)" + self.assert_compile( + table1.select((table1.c.myid != 12) & ~and_(table1.c.name=='john', table1.c.name=='ed', table1.c.name=='fred')), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :myid_1 AND "\ + "NOT (mytable.name = :name_1 AND mytable.name = :name_2 AND mytable.name = :name_3)" ) - self.runtest( - table1.select((table1.c.myid != 12) & ~table1.c.name), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name" + self.assert_compile( + table1.select((table1.c.myid != 12) & ~table1.c.name), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :myid_1 AND NOT mytable.name" ) - self.runtest( - literal("a") + literal("b") * literal("c"), ":literal || :literal_1 * :literal_2" + self.assert_compile( + literal("a") + literal("b") * literal("c"), ":param_1 || :param_2 * :param_3" ) # test the op() function, also that its results are further usable in expressions - self.runtest( + self.assert_compile( table1.select(table1.c.myid.op('hoho')(12)==14), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (mytable.myid hoho :mytable_myid) = :literal" + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (mytable.myid hoho :myid_1) = :param_1" ) - def testunicodestartswith(self): - string = u"hi \xf6 
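# --- editor's sketch (not part of the patch): LIKE / ILIKE and escape ---
# test_like asserts that escape= is carried through, and that ilike()
# compiles to lower(...) LIKE lower(...) generically but to ILIKE on Postgres:
print table1.c.myid.like('somstr', escape='\\')
# mytable.myid LIKE :myid_1 ESCAPE '\'
print table1.c.name.ilike('%something%').compile(dialect=postgres.PGDialect())
# mytable.name ILIKE %(name_1)s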
\xf5" - self.runtest( - table1.select(table1.c.name.startswith(string)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name", - checkparams = {'mytable_name': u'hi \xf6 \xf5%'}, + # test that clauses can be pickled (operators need to be module-level, etc.) + clause = (table1.c.myid == 12) & table1.c.myid.between(15, 20) & table1.c.myid.like('hoho') + assert str(clause) == str(util.pickle.loads(util.pickle.dumps(clause))) + + + def test_like(self): + for expr, check, dialect in [ + (table1.c.myid.like('somstr'), "mytable.myid LIKE :myid_1", None), + (~table1.c.myid.like('somstr'), "mytable.myid NOT LIKE :myid_1", None), + (table1.c.myid.like('somstr', escape='\\'), "mytable.myid LIKE :myid_1 ESCAPE '\\'", None), + (~table1.c.myid.like('somstr', escape='\\'), "mytable.myid NOT LIKE :myid_1 ESCAPE '\\'", None), + (table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) LIKE lower(:myid_1) ESCAPE '\\'", None), + (~table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) NOT LIKE lower(:myid_1) ESCAPE '\\'", None), + (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()), + (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()), + (table1.c.name.ilike('%something%'), "lower(mytable.name) LIKE lower(:name_1)", None), + (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgres.PGDialect()), + (~table1.c.name.ilike('%something%'), "lower(mytable.name) NOT LIKE lower(:name_1)", None), + (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgres.PGDialect()), + ]: + self.assert_compile(expr, check, dialect=dialect) + + def test_composed_string_comparators(self): + self.assert_compile( + table1.c.name.contains('jo'), "mytable.name LIKE '%%' || :name_1 || '%%'" , checkparams = {'name_1': u'jo'}, ) - - def testmultiparam(self): - self.runtest( - select(["*"], or_(table1.c.myid == 12, table1.c.myid=='asdf', table1.c.myid == 'foo')), - "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid OR mytable.myid = :mytable_myid_1 OR mytable.myid = :mytable_myid_2" + self.assert_compile( + table1.c.name.contains('jo'), "mytable.name LIKE concat(concat('%%', %s), '%%')" , checkparams = {'name_1': u'jo'}, + dialect=mysql.dialect() + ) + self.assert_compile( + table1.c.name.contains('jo', escape='\\'), "mytable.name LIKE '%%' || :name_1 || '%%' ESCAPE '\\'" , checkparams = {'name_1': u'jo'}, + ) + self.assert_compile( table1.c.name.startswith('jo', escape='\\'), "mytable.name LIKE :name_1 || '%%' ESCAPE '\\'" ) + self.assert_compile( table1.c.name.endswith('jo', escape='\\'), "mytable.name LIKE '%%' || :name_1 ESCAPE '\\'" ) + self.assert_compile( table1.c.name.endswith('hn'), "mytable.name LIKE '%%' || :name_1", checkparams = {'name_1': u'hn'}, ) + self.assert_compile( + table1.c.name.endswith('hn'), "mytable.name LIKE concat('%%', %s)", + checkparams = {'name_1': u'hn'}, dialect=mysql.dialect() + ) + self.assert_compile( + table1.c.name.startswith(u"hi \xf6 \xf5"), "mytable.name LIKE :name_1 || '%%'", + checkparams = {'name_1': u'hi \xf6 \xf5'}, + ) + self.assert_compile(column('name').endswith(text("'foo'")), "name LIKE '%%' || 'foo'" ) + self.assert_compile(column('name').endswith(literal_column("'foo'")), "name LIKE '%%' || 'foo'" ) + self.assert_compile(column('name').startswith(text("'foo'")), "name LIKE 'foo' || '%%'" ) + 
self.assert_compile(column('name').startswith(text("'foo'")), "name LIKE concat('foo', '%%')", dialect=mysql.dialect()) + self.assert_compile(column('name').startswith(literal_column("'foo'")), "name LIKE 'foo' || '%%'" ) + self.assert_compile(column('name').startswith(literal_column("'foo'")), "name LIKE concat('foo', '%%')", dialect=mysql.dialect()) + + def test_multiple_col_binds(self): + self.assert_compile( + select(["*"], or_(table1.c.myid == 12, table1.c.myid=='asdf', table1.c.myid == 'foo')), + "SELECT * FROM mytable WHERE mytable.myid = :myid_1 OR mytable.myid = :myid_2 OR mytable.myid = :myid_3" ) - def testorderby(self): - self.runtest( + def test_orderby_groupby(self): + self.assert_compile( table2.select(order_by = [table2.c.otherid, asc(table2.c.othername)]), "SELECT myothertable.otherid, myothertable.othername FROM myothertable ORDER BY myothertable.otherid, myothertable.othername ASC" ) - def testgroupby(self): - self.runtest( + + self.assert_compile( + table2.select(order_by = [table2.c.otherid, table2.c.othername.desc()]), + "SELECT myothertable.otherid, myothertable.othername FROM myothertable ORDER BY myothertable.otherid, myothertable.othername DESC" + ) + + # generative order_by + self.assert_compile( + table2.select().order_by(table2.c.otherid).order_by(table2.c.othername.desc()), + "SELECT myothertable.otherid, myothertable.othername FROM myothertable ORDER BY myothertable.otherid, myothertable.othername DESC" + ) + + self.assert_compile( + table2.select().order_by(table2.c.otherid).order_by(table2.c.othername.desc()).order_by(None), + "SELECT myothertable.otherid, myothertable.othername FROM myothertable" + ) + + self.assert_compile( select([table2.c.othername, func.count(table2.c.otherid)], group_by = [table2.c.othername]), - "SELECT myothertable.othername, count(myothertable.otherid) FROM myothertable GROUP BY myothertable.othername" + "SELECT myothertable.othername, count(myothertable.otherid) AS count_1 FROM myothertable GROUP BY myothertable.othername" ) - def testoraclelimit(self): - metadata = MetaData() - users = Table('users', metadata, Column('name', String(10), key='username')) - s = select([users.c.username], limit=5) - self.runtest(s, "SELECT name FROM (SELECT users.name AS name, ROW_NUMBER() OVER (ORDER BY users.rowid) AS ora_rn FROM users) WHERE ora_rn<=5", dialect=oracle.dialect()) - self.runtest(s, "SELECT name FROM (SELECT users.name AS name, ROW_NUMBER() OVER (ORDER BY users.rowid) AS ora_rn FROM users) WHERE ora_rn<=5", dialect=oracle.dialect()) + # generative group by + self.assert_compile( + select([table2.c.othername, func.count(table2.c.otherid)]).group_by(table2.c.othername), + "SELECT myothertable.othername, count(myothertable.otherid) AS count_1 FROM myothertable GROUP BY myothertable.othername" + ) - def testgroupby_and_orderby(self): - self.runtest( + self.assert_compile( + select([table2.c.othername, func.count(table2.c.otherid)]).group_by(table2.c.othername).group_by(None), + "SELECT myothertable.othername, count(myothertable.otherid) AS count_1 FROM myothertable" + ) + + self.assert_compile( select([table2.c.othername, func.count(table2.c.otherid)], group_by = [table2.c.othername], order_by = [table2.c.othername]), - "SELECT myothertable.othername, count(myothertable.otherid) FROM myothertable GROUP BY myothertable.othername ORDER BY myothertable.othername" + "SELECT myothertable.othername, count(myothertable.otherid) AS count_1 FROM myothertable GROUP BY myothertable.othername ORDER BY myothertable.othername" ) - - def 
testforupdate(self): - self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE") - - self.runtest(table1.select(table1.c.myid==7, for_update="nowait"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE") - self.runtest(table1.select(table1.c.myid==7, for_update="nowait"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE NOWAIT", dialect=oracle.dialect()) + def test_for_update(self): + self.assert_compile(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") - self.runtest(table1.select(table1.c.myid==7, for_update="read"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE", dialect=mysql.dialect()) + self.assert_compile(table1.select(table1.c.myid==7, for_update="nowait"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE") - self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = %s FOR UPDATE", dialect=mysql.dialect()) + self.assert_compile(table1.select(table1.c.myid==7, for_update="nowait"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE NOWAIT", dialect=oracle.dialect()) - self.runtest(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid FOR UPDATE", dialect=oracle.dialect()) - - def testalias(self): + self.assert_compile(table1.select(table1.c.myid==7, for_update="read"), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE", dialect=mysql.dialect()) + + self.assert_compile(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = %s FOR UPDATE", dialect=mysql.dialect()) + + self.assert_compile(table1.select(table1.c.myid==7, for_update=True), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :myid_1 FOR UPDATE", dialect=oracle.dialect()) + + def test_alias(self): # test the alias for a table1. column names stay the same, table name "changes" to "foo". - self.runtest( + self.assert_compile( select([table1.alias('foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo") for dialect in (firebird.dialect(), oracle.dialect()): - self.runtest( + self.assert_compile( select([table1.alias('foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo" ,dialect=dialect) - self.runtest( + self.assert_compile( select([table1.alias()]) ,"SELECT mytable_1.myid, mytable_1.name, mytable_1.description FROM mytable AS mytable_1") @@ -431,47 +577,47 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A # labels tablename_columnname, which become the column keys accessible off the Selectable object. # also, only use one column from the second table and all columns from the first table1. q = select([table1, table2.c.otherid], table1.c.myid == table2.c.otherid, use_labels = True) - + # make an alias of the "selectable". 
column names stay the same (i.e. the labels), table name "changes" to "t2view". a = alias(q, 't2view') # select from that alias, also using labels. two levels of labels should produce two underscores. # also, reference the column "mytable_myid" off of the t2view alias. - self.runtest( + self.assert_compile( a.select(a.c.mytable_myid == 9, use_labels = True), "SELECT t2view.mytable_myid AS t2view_mytable_myid, t2view.mytable_name AS t2view_mytable_name, \ t2view.mytable_description AS t2view_mytable_description, t2view.myothertable_otherid AS t2view_myothertable_otherid FROM \ (SELECT mytable.myid AS mytable_myid, mytable.name AS mytable_name, mytable.description AS mytable_description, \ myothertable.otherid AS myothertable_otherid FROM mytable, myothertable \ -WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = :t2view_mytable_myid" +WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = :mytable_myid_1" ) - - + + def test_prefixes(self): - self.runtest(table1.select().prefix_with("SQL_CALC_FOUND_ROWS").prefix_with("SQL_SOME_WEIRD_MYSQL_THING"), + self.assert_compile(table1.select().prefix_with("SQL_CALC_FOUND_ROWS").prefix_with("SQL_SOME_WEIRD_MYSQL_THING"), "SELECT SQL_CALC_FOUND_ROWS SQL_SOME_WEIRD_MYSQL_THING mytable.myid, mytable.name, mytable.description FROM mytable" ) - - def testtext(self): - self.runtest( + + def test_text(self): + self.assert_compile( text("select * from foo where lala = bar") , "select * from foo where lala = bar" ) # test bytestring - self.runtest(select( + self.assert_compile(select( ["foobar(a)", "pk_foo_bar(syslaal)"], "a = 12", from_obj = ["foobar left outer join lala on foobar.foo = lala.foo"] - ), + ), "SELECT foobar(a), pk_foo_bar(syslaal) FROM foobar left outer join lala on foobar.foo = lala.foo WHERE a = 12") # test unicode - self.runtest(select( + self.assert_compile(select( [u"foobar(a)", u"pk_foo_bar(syslaal)"], u"a = 12", from_obj = [u"foobar left outer join lala on foobar.foo = lala.foo"] - ), + ), u"SELECT foobar(a), pk_foo_bar(syslaal) FROM foobar left outer join lala on foobar.foo = lala.foo WHERE a = 12") # test building a select query programmatically with text @@ -482,60 +628,64 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = s.append_whereclause("column2=19") s = s.order_by("column1") s.append_from("table1") - self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1") + self.assert_compile(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1") - def testtextcolumns(self): - self.runtest( - select(["column1", "column2"], from_obj=[table1]).alias('somealias').select(), + self.assert_compile( + select(["column1", "column2"], from_obj=table1).alias('somealias').select(), "SELECT somealias.column1, somealias.column2 FROM (SELECT column1, column2 FROM mytable) AS somealias" ) - + # test that use_labels doesnt interfere with literal columns - self.runtest( - select(["column1", "column2", table1.c.myid], from_obj=[table1], use_labels=True), + self.assert_compile( + select(["column1", "column2", table1.c.myid], from_obj=table1, use_labels=True), "SELECT column1, column2, mytable.myid AS mytable_myid FROM mytable" ) # test that use_labels doesnt interfere with literal columns that have textual labels - self.runtest( - select(["column1 AS foobar", "column2 AS hoho", table1.c.myid], from_obj=[table1], use_labels=True), + self.assert_compile( + select(["column1 AS foobar", "column2 
AS hoho", table1.c.myid], from_obj=table1, use_labels=True), "SELECT column1 AS foobar, column2 AS hoho, mytable.myid AS mytable_myid FROM mytable" ) - + print "---------------------------------------------" s1 = select(["column1 AS foobar", "column2 AS hoho", table1.c.myid], from_obj=[table1]) print "---------------------------------------------" # test that "auto-labeling of subquery columns" doesnt interfere with literal columns, # exported columns dont get quoted - self.runtest( + self.assert_compile( select(["column1 AS foobar", "column2 AS hoho", table1.c.myid], from_obj=[table1]).select(), "SELECT column1 AS foobar, column2 AS hoho, myid FROM (SELECT column1 AS foobar, column2 AS hoho, mytable.myid AS myid FROM mytable)" ) + + self.assert_compile( + select(['col1','col2'], from_obj='tablename').alias('myalias'), + "SELECT col1, col2 FROM tablename" + ) - def testtextbinds(self): - self.runtest( - text("select * from foo where lala=:bar and hoho=:whee"), - "select * from foo where lala=:bar and hoho=:whee", + def test_binds_in_text(self): + self.assert_compile( + text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar', 4), bindparam('whee', 7)]), + "select * from foo where lala=:bar and hoho=:whee", checkparams={'bar':4, 'whee': 7}, - params={'bar':4, 'whee': 7, 'hoho':10}, ) - self.runtest( - text("select * from foo where clock='05:06:07'"), - "select * from foo where clock='05:06:07'", + self.assert_compile( + text("select * from foo where clock='05:06:07'"), + "select * from foo where clock='05:06:07'", checkparams={}, params={}, ) dialect = postgres.dialect() - self.runtest( - text("select * from foo where lala=:bar and hoho=:whee"), - "select * from foo where lala=%(bar)s and hoho=%(whee)s", + self.assert_compile( + text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]), + "select * from foo where lala=%(bar)s and hoho=%(whee)s", checkparams={'bar':4, 'whee': 7}, - params={'bar':4, 'whee': 7, 'hoho':10}, dialect=dialect ) - self.runtest( + + # test escaping out text() params with a backslash + self.assert_compile( text("select * from foo where clock='05:06:07' and mork='\:mindy'"), "select * from foo where clock='05:06:07' and mork=':mindy'", checkparams={}, @@ -544,224 +694,145 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = ) dialect = sqlite.dialect() - self.runtest( - text("select * from foo where lala=:bar and hoho=:whee"), - "select * from foo where lala=? and hoho=?", - checkparams=[4, 7], - params={'bar':4, 'whee': 7, 'hoho':10}, + self.assert_compile( + text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]), + "select * from foo where lala=? 
and hoho=?", + checkparams={'bar':4, 'whee':7}, dialect=dialect ) - - def testtextmix(self): - self.runtest(select( + + self.assert_compile(select( [table1, table2.c.otherid, "sysdate()", "foo, bar, lala"], and_( "foo.id = foofoo(lala)", "datetime(foo) = Today", table1.c.myid == table2.c.otherid, ) - ), + ), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, sysdate(), foo, bar, lala \ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today AND mytable.myid = myothertable.otherid") - def testtextualsubquery(self): - self.runtest(select( + self.assert_compile(select( [alias(table1, 't'), "foo.f"], "foo.f = t.id", from_obj = ["(select f from bar where lala=heyhey) foo"] - ), + ), "SELECT t.myid, t.name, t.description, foo.f FROM mytable AS t, (select f from bar where lala=heyhey) foo WHERE foo.f = t.id") - def testliteral(self): - self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]), - "SELECT :literal || :literal_1 FROM mytable") + # test Text embedded within select_from(), using binds + generate_series = text("generate_series(:x, :y, :z) as s(a)", bindparams=[bindparam('x'), bindparam('y'), bindparam('z')]) - def testcalculatedcolumns(self): + s =select([(func.current_date() + literal_column("s.a")).label("dates")]).select_from(generate_series) + self.assert_compile(s, "SELECT CURRENT_DATE + s.a AS dates FROM generate_series(:x, :y, :z) as s(a)", checkparams={'y': None, 'x': None, 'z': None}) + + self.assert_compile(s.params(x=5, y=6, z=7), "SELECT CURRENT_DATE + s.a AS dates FROM generate_series(:x, :y, :z) as s(a)", checkparams={'y': 6, 'x': 5, 'z': 7}) + + + def test_literal(self): + self.assert_compile(select([literal("foo") + literal("bar")], from_obj=[table1]), + "SELECT :param_1 || :param_2 AS anon_1 FROM mytable") + + def test_calculated_columns(self): value_tbl = table('values', column('id', Integer), column('val1', Float), column('val2', Float), ) - self.runtest( + self.assert_compile( select([value_tbl.c.id, (value_tbl.c.val2 - value_tbl.c.val1)/value_tbl.c.val1]), - "SELECT values.id, (values.val2 - values.val1) / values.val1 FROM values" + "SELECT values.id, (values.val2 - values.val1) / values.val1 AS anon_1 FROM values" ) - self.runtest( + self.assert_compile( select([value_tbl.c.id], (value_tbl.c.val2 - value_tbl.c.val1)/value_tbl.c.val1 > 2.0), - "SELECT values.id FROM values WHERE (values.val2 - values.val1) / values.val1 > :literal" + "SELECT values.id FROM values WHERE (values.val2 - values.val1) / values.val1 > :param_1" ) - self.runtest( + self.assert_compile( select([value_tbl.c.id], value_tbl.c.val1 / (value_tbl.c.val2 - value_tbl.c.val1) /value_tbl.c.val1 > 2.0), - "SELECT values.id FROM values WHERE values.val1 / (values.val2 - values.val1) / values.val1 > :literal" + "SELECT values.id FROM values WHERE values.val1 / (values.val2 - values.val1) / values.val1 > :param_1" ) - - def testfunction(self): - """tests the generation of functions using the func keyword""" - # test an expression with a function - self.runtest(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, - "lala(:lala, :lala_1, :literal, mytable.myid) * myothertable.otherid") - - # test it in a SELECT - self.runtest(select([func.count(table1.c.myid)]), - "SELECT count(mytable.myid) FROM mytable") - - # test a "dotted" function name - self.runtest(select([func.foo.bar.lala(table1.c.myid)]), - "SELECT foo.bar.lala(mytable.myid) FROM mytable") - - # test the bind parameter name with a "dotted" function name is only 
the name - # (limits the length of the bind param name) - self.runtest(select([func.foo.bar.lala(12)]), - "SELECT foo.bar.lala(:lala)") - - # test a dotted func off the engine itself - self.runtest(func.lala.hoho(7), "lala.hoho(:hoho)") - - # test None becomes NULL - self.runtest(func.my_func(1,2,None,3), "my_func(:my_func, :my_func_1, NULL, :my_func_2)") - - def testextract(self): + + def test_extract(self): """test the EXTRACT function""" - self.runtest(select([extract("month", table3.c.otherstuff)]), "SELECT extract(month FROM thirdtable.otherstuff) FROM thirdtable") - - self.runtest(select([extract("day", func.to_date("03/20/2005", "MM/DD/YYYY"))]), "SELECT extract(day FROM to_date(:to_date, :to_date_1))") - - def testjoin(self): - self.runtest( + self.assert_compile(select([extract("month", table3.c.otherstuff)]), "SELECT extract(month FROM thirdtable.otherstuff) AS extract_1 FROM thirdtable") + + self.assert_compile(select([extract("day", func.to_date("03/20/2005", "MM/DD/YYYY"))]), "SELECT extract(day FROM to_date(:to_date_1, :to_date_2)) AS extract_1") + + def test_collate(self): + for expr in (select([table1.c.name.collate('somecol')]), + select([collate(table1.c.name, 'somecol')])): + self.assert_compile( + expr, "SELECT mytable.name COLLATE somecol FROM mytable") + + expr = select([table1.c.name.collate('somecol').like('%x%')]) + self.assert_compile(expr, + "SELECT mytable.name COLLATE somecol " + "LIKE :param_1 AS anon_1 FROM mytable") + + expr = select([table1.c.name.like(collate('%x%', 'somecol'))]) + self.assert_compile(expr, + "SELECT mytable.name " + "LIKE :param_1 COLLATE somecol AS anon_1 " + "FROM mytable") + + expr = select([table1.c.name.collate('col1').like( + collate('%x%', 'col2'))]) + self.assert_compile(expr, + "SELECT mytable.name COLLATE col1 " + "LIKE :param_1 COLLATE col2 AS anon_1 " + "FROM mytable") + + expr = select([func.concat('a', 'b').collate('somecol').label('x')]) + self.assert_compile(expr, + "SELECT concat(:param_1, :param_2) " + "COLLATE somecol AS x") + + def test_joins(self): + self.assert_compile( join(table2, table1, table1.c.myid == table2.c.otherid).select(), "SELECT myothertable.otherid, myothertable.othername, mytable.myid, mytable.name, \ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertable.otherid" ) - self.runtest( + self.assert_compile( select( [table1], from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid)] ), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid") - self.runtest( + self.assert_compile( select( [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid) ]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid" ) - - self.runtest( + + self.assert_compile( join(users, addresses, users.c.user_id==addresses.c.user_id).select(), "SELECT users.user_id, users.user_name, users.password, addresses.address_id, addresses.user_id, addresses.street, addresses.city, addresses.state, addresses.zip FROM users JOIN addresses ON users.user_id = addresses.user_id" ) - - def testmultijoin(self): - self.runtest( + + self.assert_compile( select([table1, table2, table3], - + from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid).outerjoin(table3, 
table1.c.myid==table3.c.userid)] - + #from_obj = [outerjoin(join(table, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid==table3.c.userid)] ) ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON mytable.myid = thirdtable.userid" ) - self.runtest( + self.assert_compile( select([table1, table2, table3], from_obj = [outerjoin(table1, join(table2, table3, table2.c.otherid == table3.c.userid), table1.c.myid==table2.c.otherid)] ) ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN (myothertable JOIN thirdtable ON myothertable.otherid = thirdtable.userid) ON mytable.myid = myothertable.otherid" ) - - def testunion(self): - x = union( - select([table1], table1.c.myid == 5), - select([table1], table1.c.myid == 12), - order_by = [table1.c.myid], - ) - - self.runtest(x, "SELECT mytable.myid, mytable.name, mytable.description \ -FROM mytable WHERE mytable.myid = :mytable_myid UNION \ -SELECT mytable.myid, mytable.name, mytable.description \ -FROM mytable WHERE mytable.myid = :mytable_myid_1 ORDER BY mytable.myid") - - self.runtest( - union( - select([table1]), - select([table2]), - select([table3]) - ) - , - "SELECT mytable.myid, mytable.name, mytable.description \ -FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \ -FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable") - - u = union( - select([table1]), - select([table2]), - select([table3]) - ) - assert u.corresponding_column(table2.c.otherid) is u.c.otherid - - self.runtest( - union( - select([table1]), - select([table2]), - order_by=['myid'], - offset=10, - limit=5 - ) - , "SELECT mytable.myid, mytable.name, mytable.description \ -FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \ -FROM myothertable ORDER BY myid \ - LIMIT 5 OFFSET 10" - ) - - self.runtest( - union( - select([table1.c.myid, table1.c.name, func.max(table1.c.description)], table1.c.name=='name2', group_by=[table1.c.myid, table1.c.name]), - table1.select(table1.c.name=='name1') - ) - , - "SELECT mytable.myid, mytable.name, max(mytable.description) FROM mytable \ -WHERE mytable.name = :mytable_name GROUP BY mytable.myid, mytable.name UNION SELECT mytable.myid, mytable.name, mytable.description \ -FROM mytable WHERE mytable.name = :mytable_name_1" - ) - def test_compound_select_grouping(self): - self.runtest( - union_all( - select([table1.c.myid]), - union( - select([table2.c.otherid]), - select([table3.c.userid]), - ) - ) - , - "SELECT mytable.myid FROM mytable UNION ALL (SELECT myothertable.otherid FROM myothertable UNION \ -SELECT thirdtable.userid FROM thirdtable)" - ) - # This doesn't need grouping, so don't group to not give sqlite unnecessarily hard time - self.runtest( - union( - except_( - select([table2.c.otherid]), - select([table3.c.userid]), - ), - select([table1.c.myid]) - ) - , - "SELECT myothertable.otherid FROM myothertable EXCEPT SELECT thirdtable.userid FROM thirdtable \ -UNION SELECT mytable.myid FROM mytable" - ) - - def testouterjoin(self): - # test an outer join. 
the oracle module should take the ON clause of the join and - # move it up to the WHERE clause of its parent select, and append (+) to all right-hand-side columns - # within the original onclause, but leave right-hand-side columns unchanged outside of the onclause - # parameters. - query = select( [table1, table2], or_( @@ -772,35 +843,121 @@ UNION SELECT mytable.myid FROM mytable" ), from_obj = [ outerjoin(table1, table2, table1.c.myid == table2.c.otherid) ] ) - self.runtest(query, + self.assert_compile(query, "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid \ -WHERE mytable.name = %(mytable_name)s OR mytable.myid = %(mytable_myid)s OR \ -myothertable.othername != %(myothertable_othername)s OR \ +WHERE mytable.name = :name_1 OR mytable.myid = :myid_1 OR \ +myothertable.othername != :othername_1 OR \ EXISTS (select yay from foo where boo = lar)", - dialect=postgres.dialect() ) + def test_compound_selects(self): + try: + union(table3.select(), table1.select()) + except exceptions.ArgumentError, err: + assert str(err) == "All selectables passed to CompoundSelect must have identical numbers of columns; select #1 has 2 columns, select #2 has 3" + + x = union( + select([table1], table1.c.myid == 5), + select([table1], table1.c.myid == 12), + order_by = [table1.c.myid], + ) - self.runtest(query, - "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ -FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \ -(mytable.name = :mytable_name OR mytable.myid = :mytable_myid OR \ -myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo where boo = lar))", - dialect=oracle.OracleDialect(use_ansi = False)) + self.assert_compile(x, "SELECT mytable.myid, mytable.name, mytable.description \ +FROM mytable WHERE mytable.myid = :myid_1 UNION \ +SELECT mytable.myid, mytable.name, mytable.description \ +FROM mytable WHERE mytable.myid = :myid_2 ORDER BY mytable.myid") - query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid) - self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON thirdtable.userid = myothertable.otherid") - self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False)) + u1 = union( + select([table1.c.myid, table1.c.name]), + select([table2]), + select([table3]) + ) + self.assert_compile(u1, + "SELECT mytable.myid, mytable.name \ +FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \ +FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable") + + assert u1.corresponding_column(table2.c.otherid) is u1.c.myid + + assert u1.corresponding_column(table1.oid_column) is u1.oid_column + assert u1.corresponding_column(table2.oid_column) is u1.oid_column - def testbindparam(self): + # TODO - why is there an extra space before the LIMIT ? 
+ self.assert_compile( + union( + select([table1.c.myid, table1.c.name]), + select([table2]), + order_by=['myid'], + offset=10, + limit=5 + ) + , "SELECT mytable.myid, mytable.name \ +FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \ +FROM myothertable ORDER BY myid LIMIT 5 OFFSET 10" + ) + + self.assert_compile( + union( + select([table1.c.myid, table1.c.name, func.max(table1.c.description)], table1.c.name=='name2', group_by=[table1.c.myid, table1.c.name]), + table1.select(table1.c.name=='name1') + ) + , + "SELECT mytable.myid, mytable.name, max(mytable.description) AS max_1 FROM mytable \ +WHERE mytable.name = :name_1 GROUP BY mytable.myid, mytable.name UNION SELECT mytable.myid, mytable.name, mytable.description \ +FROM mytable WHERE mytable.name = :name_2" + ) + + self.assert_compile( + union( + select([literal(100).label('value')]), + select([literal(200).label('value')]) + ), + "SELECT :param_1 AS value UNION SELECT :param_2 AS value" + ) + + self.assert_compile( + union_all( + select([table1.c.myid]), + union( + select([table2.c.otherid]), + select([table3.c.userid]), + ) + ) + , + "SELECT mytable.myid FROM mytable UNION ALL (SELECT myothertable.otherid FROM myothertable UNION \ +SELECT thirdtable.userid FROM thirdtable)" + ) + # This doesn't need grouping, so don't group to not give sqlite unnecessarily hard time + self.assert_compile( + union( + except_( + select([table2.c.otherid]), + select([table3.c.userid]), + ), + select([table1.c.myid]) + ) + , + "SELECT myothertable.otherid FROM myothertable EXCEPT SELECT thirdtable.userid FROM thirdtable \ +UNION SELECT mytable.myid FROM mytable" + ) + + # test unions working with non-oid selectables + s = select([column('foo'), column('bar')]) + s = union(s, s) + s = union(s, s) + self.assert_compile(s, "SELECT foo, bar UNION SELECT foo, bar UNION (SELECT foo, bar UNION SELECT foo, bar)") + + + @testing.uses_deprecated('//get_params') + def test_binds(self): for ( stmt, expected_named_stmt, expected_positional_stmt, - expected_default_params_dict, + expected_default_params_dict, expected_default_params_list, - test_param_dict, + test_param_dict, expected_test_params_dict, expected_test_params_list ) in [ @@ -832,148 +989,192 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo ), ( select([table1], or_(table1.c.myid==bindparam('myid', unique=True), table2.c.otherid==bindparam('myid', unique=True))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid_1", + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid_1 OR myothertable.otherid = :myid_2", "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?", - {'myid':None, 'myid_1':None}, [None, None], - {'myid':5, 'myid_1': 6}, {'myid':5, 'myid_1':6}, [5,6] + {'myid_1':None, 'myid_2':None}, [None, None], + {'myid_1':5, 'myid_2': 6}, {'myid_1':5, 'myid_2':6}, [5,6] + ), + ( + bindparam('test', type_=String) + text("'hi'"), + ":test || 'hi'", + "? 
|| 'hi'", + {'test':None}, [None], + {}, {'test':None}, [None] + ), + ( + select([table1], or_(table1.c.myid==bindparam('myid'), table2.c.otherid==bindparam('myotherid'))).params({'myid':8, 'myotherid':7}), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myotherid", + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?", + {'myid':8, 'myotherid':7}, [8, 7], + {'myid':5}, {'myid':5, 'myotherid':7}, [5,7] ), ( select([table1], or_(table1.c.myid==bindparam('myid', value=7, unique=True), table2.c.otherid==bindparam('myid', value=8, unique=True))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid OR myothertable.otherid = :myid_1", + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :myid_1 OR myothertable.otherid = :myid_2", "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = ? OR myothertable.otherid = ?", - {'myid':7, 'myid_1':8}, [7,8], - {'myid':5, 'myid_1':6}, {'myid':5, 'myid_1':6}, [5,6] + {'myid_1':7, 'myid_2':8}, [7,8], + {'myid_1':5, 'myid_2':6}, {'myid_1':5, 'myid_2':6}, [5,6] ), - ][2:3]: - - self.runtest(stmt, expected_named_stmt, params=expected_default_params_dict) - self.runtest(stmt, expected_positional_stmt, dialect=sqlite.dialect()) + ]: + + self.assert_compile(stmt, expected_named_stmt, params=expected_default_params_dict) + self.assert_compile(stmt, expected_positional_stmt, dialect=sqlite.dialect()) nonpositional = stmt.compile() positional = stmt.compile(dialect=sqlite.dialect()) - assert positional.get_params().get_raw_list() == expected_default_params_list - assert nonpositional.get_params(**test_param_dict).get_raw_dict() == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict).get_raw_dict())) - assert positional.get_params(**test_param_dict).get_raw_list() == expected_test_params_list - + pp = positional.get_params() + assert [pp[k] for k in positional.positiontup] == expected_default_params_list + assert nonpositional.get_params(**test_param_dict) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict))) + pp = positional.get_params(**test_param_dict) + assert [pp[k] for k in positional.positiontup] == expected_test_params_list + + # check that params() doesnt modify original statement + s = select([table1], or_(table1.c.myid==bindparam('myid'), table2.c.otherid==bindparam('myotherid'))) + s2 = s.params({'myid':8, 'myotherid':7}) + s3 = s2.params({'myid':9}) + assert s.compile().params == {'myid':None, 'myotherid':None} + assert s2.compile().params == {'myid':8, 'myotherid':7} + assert s3.compile().params == {'myid':9, 'myotherid':7} + + # test using same 'unique' param object twice in one compile + s = select([table1.c.myid]).where(table1.c.myid==12).as_scalar() + s2 = select([table1, s], table1.c.myid==s) + self.assert_compile(s2, + "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable WHERE mytable.myid = "\ + ":myid_1) AS anon_1 FROM mytable WHERE mytable.myid = (SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)") + positional = s2.compile(dialect=sqlite.dialect()) + + pp = positional.get_params() + assert [pp[k] for k in 
positional.positiontup] == [12, 12] + # check that conflicts with "unique" params are caught - s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('mytable_myid'))) - try: - str(s) - assert False - except exceptions.CompileError, err: - assert str(err) == "Bind parameter 'mytable_myid' conflicts with unique bind parameter of the same name" + s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('myid_1'))) + self.assertRaisesMessage(exceptions.CompileError, "conflicts with unique bind parameter of the same name", str, s) + + s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('myid_1'))) + self.assertRaisesMessage(exceptions.CompileError, "conflicts with unique bind parameter of the same name", str, s) + - s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('mytable_myid_1'))) - try: - str(s) - assert False - except exceptions.CompileError, err: - assert str(err) == "Bind parameter 'mytable_myid_1' conflicts with unique bind parameter of the same name" - - # check that the bind params sent along with a compile() call - # get preserved when the params are retreived later - s = select([table1], table1.c.myid == bindparam('test')) - c = s.compile(parameters = {'test' : 7}) - self.assert_(c.get_params().get_original_dict() == {'test' : 7}) - def testbindascol(self): + def test_bind_as_col(self): t = table('foo', column('id')) s = select([t, literal('lala').label('hoho')]) - self.runtest(s, "SELECT foo.id, :literal AS hoho FROM foo") + self.assert_compile(s, "SELECT foo.id, :param_1 AS hoho FROM foo") assert [str(c) for c in s.c] == ["id", "hoho"] - - def testin(self): - self.runtest(select([table1], table1.c.myid.in_('a')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid") - self.runtest(select([table1], table1.c.myid.in_('a', 'b')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_myid_1)") + def test_in(self): + self.assert_compile(select([table1], table1.c.myid.in_(['a'])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)") + + self.assert_compile(select([table1], ~table1.c.myid.in_(['a'])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid NOT IN (:myid_1)") - self.runtest(select([table1], table1.c.myid.in_(literal('a'))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal") + self.assert_compile(select([table1], table1.c.myid.in_(['a', 'b'])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :myid_2)") - self.runtest(select([table1], table1.c.myid.in_(literal('a'), 'b')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, :mytable_myid)") + self.assert_compile(select([table1], table1.c.myid.in_(iter(['a', 'b']))), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :myid_2)") - self.runtest(select([table1], table1.c.myid.in_(literal('a'), literal('b'))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, :literal_1)") + self.assert_compile(select([table1], table1.c.myid.in_([literal('a')])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1)") - 
self.runtest(select([table1], table1.c.myid.in_('a', literal('b'))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal)") + self.assert_compile(select([table1], table1.c.myid.in_([literal('a'), 'b'])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1, :myid_1)") - self.runtest(select([table1], table1.c.myid.in_(literal(1) + 'a')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal + :literal_1") + self.assert_compile(select([table1], table1.c.myid.in_([literal('a'), literal('b')])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1, :param_2)") - self.runtest(select([table1], table1.c.myid.in_(literal('a') +'a', 'b')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :mytable_myid)") + self.assert_compile(select([table1], table1.c.myid.in_(['a', literal('b')])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :param_1)") - self.runtest(select([table1], table1.c.myid.in_(literal('a') + literal('a'), literal('b'))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :literal_2)") + self.assert_compile(select([table1], table1.c.myid.in_([literal(1) + 'a'])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1 + :param_2)") - self.runtest(select([table1], table1.c.myid.in_(1, literal(3) + 4)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal + :literal_1)") + self.assert_compile(select([table1], table1.c.myid.in_([literal('a') +'a', 'b'])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1 || :param_2, :myid_1)") - self.runtest(select([table1], table1.c.myid.in_(literal('a') < 'b')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (:literal < :literal_1)") + self.assert_compile(select([table1], table1.c.myid.in_([literal('a') + literal('a'), literal('b')])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1 || :param_2, :param_3)") - self.runtest(select([table1], table1.c.myid.in_(table1.c.myid)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = mytable.myid") + self.assert_compile(select([table1], table1.c.myid.in_([1, literal(3) + 4])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :param_1 + :param_2)") - self.runtest(select([table1], table1.c.myid.in_('a', table1.c.myid)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, mytable.myid)") + self.assert_compile(select([table1], table1.c.myid.in_([literal('a') < 'b'])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1 < :param_2)") - self.runtest(select([table1], table1.c.myid.in_(literal('a'), table1.c.myid)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, mytable.myid)") + self.assert_compile(select([table1], table1.c.myid.in_([table1.c.myid])), + "SELECT mytable.myid, mytable.name, mytable.description 
FROM mytable WHERE mytable.myid IN (mytable.myid)") - self.runtest(select([table1], table1.c.myid.in_(literal('a'), table1.c.myid +'a')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, mytable.myid + :mytable_myid)") + self.assert_compile(select([table1], table1.c.myid.in_(['a', table1.c.myid])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, mytable.myid)") - self.runtest(select([table1], table1.c.myid.in_(literal(1), 'a' + table1.c.myid)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, :mytable_myid + mytable.myid)") + self.assert_compile(select([table1], table1.c.myid.in_([literal('a'), table1.c.myid])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1, mytable.myid)") - self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_myid_1, :mytable_myid_2)") + self.assert_compile(select([table1], table1.c.myid.in_([literal('a'), table1.c.myid +'a'])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1, mytable.myid + :myid_1)") - self.runtest(select([table1], table1.c.myid.in_(select([table2.c.otherid]))), + self.assert_compile(select([table1], table1.c.myid.in_([literal(1), 'a' + table1.c.myid])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:param_1, :myid_1 + mytable.myid)") + + self.assert_compile(select([table1], table1.c.myid.in_([1, 2, 3])), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :myid_2, :myid_3)") + + self.assert_compile(select([table1], table1.c.myid.in_(select([table2.c.otherid]))), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid FROM myothertable)") - self.runtest(select([table1], ~table1.c.myid.in_(select([table2.c.otherid]))), + self.assert_compile(select([table1], ~table1.c.myid.in_(select([table2.c.otherid]))), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid NOT IN (SELECT myothertable.otherid FROM myothertable)") - self.runtest(select([table1], table1.c.myid.in_( + self.assert_compile(select([table1], table1.c.myid.in_( union( select([table1], table1.c.myid == 5), select([table1], table1.c.myid == 12), ) )), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable \ WHERE mytable.myid IN (\ -SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid \ -UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid_1)") - +SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :myid_1 \ +UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :myid_2)") + # test that putting a select in an IN clause does not blow away its ORDER BY clause - self.runtest( - select([table1, table2], + self.assert_compile( + select([table1, table2], table2.c.otherid.in_( select([table2.c.otherid], order_by=[table2.c.othername], limit=10, correlate=False) ), from_obj=[table1.join(table2, table1.c.myid==table2.c.otherid)], order_by=[table1.c.myid] ), - "SELECT mytable.myid, mytable.name, mytable.description, 
myothertable.otherid, myothertable.othername FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid WHERE myothertable.otherid IN (SELECT myothertable.otherid FROM myothertable ORDER BY myothertable.othername LIMIT 10) ORDER BY mytable.myid" + "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable "\ + "JOIN myothertable ON mytable.myid = myothertable.otherid WHERE myothertable.otherid IN (SELECT myothertable.otherid "\ + "FROM myothertable ORDER BY myothertable.othername LIMIT 10) ORDER BY mytable.myid" ) - + # test empty in clause - self.runtest(select([table1], table1.c.myid.in_()), + self.assert_compile(select([table1], table1.c.myid.in_([])), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)") - - - def testcast(self): + + @testing.uses_deprecated('passing in_') + def test_in_deprecated_api(self): + self.assert_compile(select([table1], table1.c.myid.in_('abc')), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)") + + self.assert_compile(select([table1], table1.c.myid.in_(1)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)") + + self.assert_compile(select([table1], table1.c.myid.in_(1,2)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :myid_2)") + + self.assert_compile(select([table1], table1.c.myid.in_()), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)") + + def test_cast(self): tbl = table('casttest', column('id', Integer), column('v1', Float), column('v2', Float), column('ts', TIMESTAMP), ) - + def check_results(dialect, expected_results, literal): self.assertEqual(len(expected_results), 5, 'Incorrect number of expected results') self.assertEqual(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) @@ -981,160 +1182,261 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2]) self.assertEqual(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) self.assertEqual(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4])) - sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect) - self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) \nFROM casttest") + # fixme: shoving all of this dialect-specific stuff in one test + # is now officialy completely ridiculous AND non-obviously omits + # coverage on other dialects. 
+ sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect) + if isinstance(dialect, type(mysql.dialect())): + self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest") + else: + self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest") + # first test with Postgres engine - check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s') + check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s') # then the Oracle engine - check_results(oracle.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal') + check_results(oracle.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1') # then the sqlite engine check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') - # MySQL seems to only support DATE types for cast - self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=mysql.dialect())), 'CAST(casttest.ts AS DATE)') - self.assertEqual(str(cast(tbl.c.ts, Numeric).compile(dialect=mysql.dialect())), 'casttest.ts') + # then the MySQL engine + check_results(mysql.dialect(), ['DECIMAL(10, 2)', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s') - def testdatebetween(self): + self.assert_compile(cast(text('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) + self.assert_compile(cast(null(), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) + self.assert_compile(cast(literal_column('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) + + def test_date_between(self): import datetime - table = Table('dt', metadata, + table = Table('dt', metadata, Column('date', Date)) - self.runtest(table.select(table.c.date.between(datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :dt_date AND :dt_date_1", checkparams={'dt_date':datetime.date(2006,6,1), 'dt_date_1':datetime.date(2006,6,5)}) + self.assert_compile(table.select(table.c.date.between(datetime.date(2006,6,1), datetime.date(2006,6,5))), + "SELECT dt.date FROM dt WHERE dt.date BETWEEN :date_1 AND :date_2", checkparams={'date_1':datetime.date(2006,6,1), 'date_2':datetime.date(2006,6,5)}) + + self.assert_compile(table.select(sql.between(table.c.date, datetime.date(2006,6,1), datetime.date(2006,6,5))), + "SELECT dt.date FROM dt WHERE dt.date BETWEEN :param_1 AND :param_2", checkparams={'param_1':datetime.date(2006,6,1), 'param_2':datetime.date(2006,6,5)}) - self.runtest(table.select(sql.between(table.c.date, datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :literal AND :literal_1", checkparams={'literal':datetime.date(2006,6,1), 'literal_1':datetime.date(2006,6,5)}) - def test_operator_precedence(self): table = Table('op', metadata, Column('field', Integer)) - self.runtest(table.select((table.c.field == 5) == None), - "SELECT op.field FROM op WHERE (op.field = :op_field) IS NULL") - self.runtest(table.select((table.c.field + 5) == table.c.field), - "SELECT op.field FROM op WHERE op.field + :op_field = op.field") - self.runtest(table.select((table.c.field + 5) * 6), - "SELECT op.field FROM op WHERE (op.field + :op_field) * :literal") - self.runtest(table.select((table.c.field * 5) + 6), - "SELECT 
op.field FROM op WHERE op.field * :op_field + :literal") - self.runtest(table.select(5 + table.c.field.in_(5,6)), - "SELECT op.field FROM op WHERE :literal + (op.field IN (:op_field, :op_field_1))") - self.runtest(table.select((5 + table.c.field).in_(5,6)), - "SELECT op.field FROM op WHERE :op_field + op.field IN (:literal, :literal_1)") - self.runtest(table.select(not_(and_(table.c.field == 5, table.c.field == 7))), - "SELECT op.field FROM op WHERE NOT (op.field = :op_field AND op.field = :op_field_1)") - self.runtest(table.select(not_(table.c.field) == 5), - "SELECT op.field FROM op WHERE (NOT op.field) = :literal") - self.runtest(table.select((table.c.field == table.c.field).between(False, True)), - "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :literal AND :literal_1") - self.runtest(table.select(between((table.c.field == table.c.field), False, True)), - "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :literal AND :literal_1") - -class CRUDTest(SQLTest): - def testinsert(self): + self.assert_compile(table.select((table.c.field == 5) == None), + "SELECT op.field FROM op WHERE (op.field = :field_1) IS NULL") + self.assert_compile(table.select((table.c.field + 5) == table.c.field), + "SELECT op.field FROM op WHERE op.field + :field_1 = op.field") + self.assert_compile(table.select((table.c.field + 5) * 6), + "SELECT op.field FROM op WHERE (op.field + :field_1) * :param_1") + self.assert_compile(table.select((table.c.field * 5) + 6), + "SELECT op.field FROM op WHERE op.field * :field_1 + :param_1") + self.assert_compile(table.select(5 + table.c.field.in_([5,6])), + "SELECT op.field FROM op WHERE :param_1 + (op.field IN (:field_1, :field_2))") + self.assert_compile(table.select((5 + table.c.field).in_([5,6])), + "SELECT op.field FROM op WHERE :field_1 + op.field IN (:param_1, :param_2)") + self.assert_compile(table.select(not_(and_(table.c.field == 5, table.c.field == 7))), + "SELECT op.field FROM op WHERE NOT (op.field = :field_1 AND op.field = :field_2)") + self.assert_compile(table.select(not_(table.c.field == 5)), + "SELECT op.field FROM op WHERE op.field != :field_1") + self.assert_compile(table.select(not_(table.c.field.between(5, 6))), + "SELECT op.field FROM op WHERE NOT (op.field BETWEEN :field_1 AND :field_2)") + self.assert_compile(table.select(not_(table.c.field) == 5), + "SELECT op.field FROM op WHERE (NOT op.field) = :param_1") + self.assert_compile(table.select((table.c.field == table.c.field).between(False, True)), + "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :param_1 AND :param_2") + self.assert_compile(table.select(between((table.c.field == table.c.field), False, True)), + "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :param_1 AND :param_2") + + def test_naming(self): + s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')]) + assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg'] + + from sqlalchemy.databases.sqlite import SLNumeric + meta = MetaData() + t1 = Table('mytable', meta, Column('col1', Integer)) + + for col, key, expr, label in ( + (table1.c.name, 'name', 'mytable.name', None), + (table1.c.myid==12, 'mytable.myid = :myid_1', 'mytable.myid = :myid_1', 'anon_1'), + (func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'), + (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'), + (t1.c.col1, 'col1', 'mytable.col1', None), + 
(column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '') + ): + s1 = select([col], from_obj=getattr(col, 'table', None) or table1) + assert s1.c.keys() == [key], s1.c.keys() + + if label: + self.assert_compile(s1, "SELECT %s AS %s FROM mytable" % (expr, label)) + else: + self.assert_compile(s1, "SELECT %s FROM mytable" % (expr,)) + + s1 = select([s1]) + if label: + self.assert_compile(s1, "SELECT %s FROM (SELECT %s AS %s FROM mytable)" % (label, expr, label)) + elif col.table is not None: + # sqlite rule labels subquery columns + self.assert_compile(s1, "SELECT %s FROM (SELECT %s AS %s FROM mytable)" % (key,expr, key)) + else: + self.assert_compile(s1, "SELECT %s FROM (SELECT %s FROM mytable)" % (expr,expr)) + +class CRUDTest(TestBase, AssertsCompiledSQL): + def test_insert(self): # generic insert, will create bind params for all columns - self.runtest(insert(table1), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)") + self.assert_compile(insert(table1), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)") # insert with user-supplied bind params for specific columns, # cols provided literally - self.runtest( - insert(table1, {table1.c.myid : bindparam('userid'), table1.c.name : bindparam('username')}), + self.assert_compile( + insert(table1, {table1.c.myid : bindparam('userid'), table1.c.name : bindparam('username')}), "INSERT INTO mytable (myid, name) VALUES (:userid, :username)") - + # insert with user-supplied bind params for specific columns, cols # provided as strings - self.runtest( - insert(table1, dict(myid = 3, name = 'jack')), + self.assert_compile( + insert(table1, dict(myid = 3, name = 'jack')), "INSERT INTO mytable (myid, name) VALUES (:myid, :name)" ) # test with a tuple of params instead of named - self.runtest( - insert(table1, (3, 'jack', 'mydescription')), + self.assert_compile( + insert(table1, (3, 'jack', 'mydescription')), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)", checkparams = {'myid':3, 'name':'jack', 'description':'mydescription'} ) - - - def testinsertexpression(self): - self.runtest(insert(table1), "INSERT INTO mytable (myid) VALUES (lala())", params=dict(myid=func.lala())) - - def testupdate(self): - self.runtest(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table1.c.name:'fred'}) - self.runtest(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'}) - self.runtest(update(table1, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid") - self.runtest(update(table1, whereclause = table1.c.name == bindparam('crit'), values = {table1.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'}) - self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'}) - self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'}) + + self.assert_compile( + insert(table1, values={table1.c.myid : bindparam('userid')}).values({table1.c.name : bindparam('username')}), + "INSERT INTO mytable (myid, name) 
VALUES (:userid, :username)" + ) + + self.assert_compile(insert(table1, values=dict(myid=func.lala())), "INSERT INTO mytable (myid) VALUES (lala())") + + def test_inline_insert(self): + metadata = MetaData() + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('foo', Integer, default=func.foobar())) + self.assert_compile(table.insert(values={}, inline=True), "INSERT INTO sometable (foo) VALUES (foobar())") + self.assert_compile(table.insert(inline=True), "INSERT INTO sometable (foo) VALUES (foobar())", params={}) + + def test_update(self): + self.assert_compile(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", params = {table1.c.name:'fred'}) + self.assert_compile(table1.update().where(table1.c.myid==7).values({table1.c.myid:5}), "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", checkparams={'myid':5, 'myid_1':7}) + self.assert_compile(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", params = {'name':'fred'}) + self.assert_compile(update(table1, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid") + self.assert_compile(update(table1, whereclause = table1.c.name == bindparam('crit'), values = {table1.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'}, checkparams={'crit':'notthere', 'name':'hi'}) + self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :myid_1", params = {'description':'test'}, checkparams={'description':'test', 'myid_1':12}) + self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :myid_1", params = {'myid_1': 12, 'myid': 9, 'description': 'test'}) + self.assert_compile(update(table1, table1.c.myid ==12), "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", params={'myid':18}, checkparams={'myid':18, 'myid_1':12}) s = table1.update(table1.c.myid == 12, values = {table1.c.name : 'lala'}) - c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}) + c = s.compile(column_keys=['id', 'name']) + self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}).values({table1.c.name:table1.c.name + 'foo'}), "UPDATE mytable SET name=(mytable.name || :name_1), description=:description WHERE mytable.myid = :myid_1", params = {'description':'test'}) self.assert_(str(s) == str(c)) - - def testupdateexpression(self): - self.runtest(update(table1, + + self.assert_compile(update(table1, (table1.c.myid == func.hoho(4)) & (table1.c.name == literal('foo') + table1.c.name + literal('lala')), values = { table1.c.name : table1.c.name + "lala", table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho')) - }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal), name=(mytable.name || :mytable_name) " - "WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal_1 || mytable.name || :literal_2") - - def testcorrelatedupdate(self): + }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :param_1), name=(mytable.name || :name_1) " + "WHERE mytable.myid = hoho(:hoho_1) AND mytable.name = :param_2 || mytable.name || :param_3") + + def test_correlated_update(self): # test against a straight text subquery u = update(table1, values = {table1.c.name : text("(select name from mytable where 
id=mytable.id)")}) - self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)") + self.assert_compile(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)") mt = table1.alias() u = update(table1, values = {table1.c.name : select([mt.c.name], mt.c.myid==table1.c.myid)}) - self.runtest(u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)") - + self.assert_compile(u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)") + # test against a regular constructed subquery s = select([table2], table2.c.otherid == table1.c.myid) u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s}) - self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name") + self.assert_compile(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :name_1") # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) u = update(table1, table1.c.name==s) - self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)") + self.assert_compile(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = "\ + "(SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :otherid_1)") # test one that is actually correlated... s = select([table2.c.othername], table2.c.otherid == table1.c.myid) u = table1.update(table1.c.name==s) - self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") + self.assert_compile(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = "\ + "(SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") + + def test_delete(self): + self.assert_compile(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :myid_1") + self.assert_compile(table1.delete().where(table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :myid_1") + self.assert_compile(table1.delete().where(table1.c.myid == 7).where(table1.c.name=='somename'), "DELETE FROM mytable WHERE mytable.myid = :myid_1 AND mytable.name = :name_1") - def testdelete(self): - self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") - - def testcorrelateddelete(self): + def test_correlated_delete(self): # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) u = delete(table1, table1.c.name==s) - self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)") + self.assert_compile(u, "DELETE FROM mytable WHERE mytable.name = "\ + "(SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :otherid_1)") # test one that is actually correlated... 
s = select([table2.c.othername], table2.c.otherid == table1.c.myid) u = table1.delete(table1.c.name==s) - self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") - -class SchemaTest(SQLTest): - def testselect(self): + self.assert_compile(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") + +class InlineDefaultTest(TestBase, AssertsCompiledSQL): + def test_insert(self): + m = MetaData() + foo = Table('foo', m, + Column('id', Integer)) + + t = Table('test', m, + Column('col1', Integer, default=func.foo(1)), + Column('col2', Integer, default=select([func.coalesce(func.max(foo.c.id))])), + ) + + self.assert_compile(t.insert(inline=True, values={}), "INSERT INTO test (col1, col2) VALUES (foo(:foo_1), (SELECT coalesce(max(foo.id)) AS coalesce_1 FROM foo))") + + def test_update(self): + m = MetaData() + foo = Table('foo', m, + Column('id', Integer)) + + t = Table('test', m, + Column('col1', Integer, onupdate=func.foo(1)), + Column('col2', Integer, onupdate=select([func.coalesce(func.max(foo.c.id))])), + Column('col3', String(30)) + ) + + self.assert_compile(t.update(inline=True, values={'col3':'foo'}), "UPDATE test SET col1=foo(:foo_1), col2=(SELECT coalesce(max(foo.id)) AS coalesce_1 FROM foo), col3=:col3") + +class SchemaTest(TestBase, AssertsCompiledSQL): + @testing.fails_on('mssql') + def test_select(self): # these tests will fail with the MS-SQL compiler since it will alias schema-qualified tables - self.runtest(table4.select(), "SELECT remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM remote_owner.remotetable") - self.runtest(table4.select(and_(table4.c.datatype_id==7, table4.c.value=='hi')), "SELECT remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM remote_owner.remotetable WHERE remotetable.datatype_id = :remotetable_datatype_id AND remotetable.value = :remotetable_value") + self.assert_compile(table4.select(), "SELECT remote_owner.remotetable.rem_id, remote_owner.remotetable.datatype_id, remote_owner.remotetable.value FROM remote_owner.remotetable") + self.assert_compile(table4.select(and_(table4.c.datatype_id==7, table4.c.value=='hi')), + "SELECT remote_owner.remotetable.rem_id, remote_owner.remotetable.datatype_id, remote_owner.remotetable.value FROM remote_owner.remotetable WHERE "\ + "remote_owner.remotetable.datatype_id = :datatype_id_1 AND remote_owner.remotetable.value = :value_1") s = table4.select(and_(table4.c.datatype_id==7, table4.c.value=='hi')) s.use_labels = True - self.runtest(s, "SELECT remotetable.rem_id AS remotetable_rem_id, remotetable.datatype_id AS remotetable_datatype_id, remotetable.value AS remotetable_value FROM remote_owner.remotetable WHERE remotetable.datatype_id = :remotetable_datatype_id AND remotetable.value = :remotetable_value") + self.assert_compile(s, "SELECT remote_owner.remotetable.rem_id AS remote_owner_remotetable_rem_id, remote_owner.remotetable.datatype_id AS remote_owner_remotetable_datatype_id, remote_owner.remotetable.value "\ + "AS remote_owner_remotetable_value FROM remote_owner.remotetable WHERE "\ + "remote_owner.remotetable.datatype_id = :datatype_id_1 AND remote_owner.remotetable.value = :value_1") - def testalias(self): + def test_alias(self): a = alias(table4, 'remtable') - self.runtest(a.select(a.c.datatype_id==7), "SELECT remtable.rem_id, remtable.datatype_id, remtable.value FROM remote_owner.remotetable AS remtable 
WHERE remtable.datatype_id = :remtable_datatype_id") - - def testupdate(self): - self.runtest(table4.update(table4.c.value=='test', values={table4.c.datatype_id:12}), "UPDATE remote_owner.remotetable SET datatype_id=:datatype_id WHERE remotetable.value = :remotetable_value") - - def testinsert(self): - self.runtest(table4.insert(values=(2, 5, 'test')), "INSERT INTO remote_owner.remotetable (rem_id, datatype_id, value) VALUES (:rem_id, :datatype_id, :value)") - + self.assert_compile(a.select(a.c.datatype_id==7), "SELECT remtable.rem_id, remtable.datatype_id, remtable.value FROM remote_owner.remotetable AS remtable "\ + "WHERE remtable.datatype_id = :datatype_id_1") + + def test_update(self): + self.assert_compile(table4.update(table4.c.value=='test', values={table4.c.datatype_id:12}), "UPDATE remote_owner.remotetable SET datatype_id=:datatype_id "\ + "WHERE remote_owner.remotetable.value = :value_1") + + def test_insert(self): + self.assert_compile(table4.insert(values=(2, 5, 'test')), "INSERT INTO remote_owner.remotetable (rem_id, datatype_id, value) VALUES "\ + "(:rem_id, :datatype_id, :value)") + if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/sql/selectable.py b/test/sql/selectable.py index dcc8550747..b29ba8d5c0 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -1,227 +1,480 @@ -"""tests that various From objects properly export their columns, as well as -useable primary keys and foreign keys. Full relational algebra depends on -every selectable unit behaving nicely with others..""" - -import testbase -from sqlalchemy import * -from testlib import * - -metadata = MetaData() -table = Table('table1', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', String(20)), - Column('col3', Integer), - Column('colx', Integer), - -) - -table2 = Table('table2', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', Integer, ForeignKey('table1.col1')), - Column('col3', String(20)), - Column('coly', Integer), -) - -class SelectableTest(AssertMixin): - def testdistance(self): - s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')]) - - # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far - #assert s.corresponding_column(table.c.col1) is s.c.col1 - assert s.corresponding_column(s.c.col1) is s.c.col1 - assert s.corresponding_column(s.c.c1) is s.c.c1 - - def testjoinagainstself(self): - jj = select([table.c.col1.label('bar_col1')]) - jjj = join(table, jj, table.c.col1==jj.c.bar_col1) - - # test column directly agaisnt itself - assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 - - assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1 - - # test alias of the join, targets the column with the least - # "distance" between the requested column and the returned column - # (i.e. 
there is less indirection between j2.c.table1_col1 and table.c.col1, than - # there is from j2.c.bar_col1 to table.c.col1) - j2 = jjj.alias('foo') - assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1 - - - def testjoinagainstjoin(self): - j = outerjoin(table, table2, table.c.col1==table2.c.col2) - jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') - jjj = join(table, jj, table.c.col1==jj.c.bar_col1) - assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 - - j2 = jjj.alias('foo') - print j2.corresponding_column(jjj.c.table1_col1) - assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1 - - assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1 - - def testtablealias(self): - a = table.alias('a') - - j = join(a, table2) - - criterion = a.c.col1 == table2.c.col2 - print - print str(j) - self.assert_(criterion.compare(j.onclause)) - - def testunion(self): - # tests that we can correspond a column in a Select statement with a certain Table, against - # a column in a Union where one of its underlying Selects matches to that same Table - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) - ) - s1 = table.select(use_labels=True) - s2 = table2.select(use_labels=True) - print ["%d %s" % (id(c),c.key) for c in u.c] - c = u.corresponding_column(s1.c.table1_col2) - print "%d %s" % (id(c), c.key) - print id(u.corresponding_column(s1.c.table1_col2).table) - print id(u.c.col2.table) - assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 - assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 - - def testaliasunion(self): - # same as testunion, except its an alias of the union - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) - ).alias('analias') - s1 = table.select(use_labels=True) - s2 = table2.select(use_labels=True) - assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 - assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 - assert u.corresponding_column(s2.c.table2_coly) is u.c.coly - assert s2.corresponding_column(u.c.coly) is s2.c.table2_coly - - def testselectunion(self): - # like testaliasunion, but off a Select off the union. 
- u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) - ).alias('analias') - s = select([u]) - s1 = table.select(use_labels=True) - s2 = table2.select(use_labels=True) - assert s.corresponding_column(s1.c.table1_col2) is s.c.col2 - assert s.corresponding_column(s2.c.table2_col2) is s.c.col2 - - def testunionagainstjoin(self): - # same as testunion, except its an alias of the union - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) - ).alias('analias') - j1 = table.join(table2) - assert u.corresponding_column(j1.c.table1_colx) is u.c.colx - assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx - - def testjoin(self): - a = join(table, table2) - print str(a.select(use_labels=True)) - b = table2.alias('b') - j = join(a, b) - print str(j) - criterion = a.c.table1_col1 == b.c.col2 - self.assert_(criterion.compare(j.onclause)) - - def testselectalias(self): - a = table.select().alias('a') - print str(a.select()) - j = join(a, table2) - - criterion = a.c.col1 == table2.c.col2 - print criterion - print j.onclause - self.assert_(criterion.compare(j.onclause)) - - def testselectlabels(self): - a = table.select(use_labels=True) - print str(a.select()) - j = join(a, table2) - - criterion = a.c.table1_col1 == table2.c.col2 - print - print str(j) - self.assert_(criterion.compare(j.onclause)) - - def testcolumnlabels(self): - a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')]) - print str(a) - print [c for c in a.columns] - print str(a.select()) - j = join(a, table2) - criterion = a.c.acol1 == table2.c.col2 - print str(j) - self.assert_(criterion.compare(j.onclause)) - - def testselectaliaslabels(self): - a = table2.select(use_labels=True).alias('a') - print str(a.select()) - j = join(a, table) - - criterion = table.c.col1 == a.c.table2_col2 - print str(criterion) - print str(j.onclause) - self.assert_(criterion.compare(j.onclause)) - - -class PrimaryKeyTest(AssertMixin): - def test_join_pk_collapse_implicit(self): - """test that redundant columns in a join get 'collapsed' into a minimal primary key, - which is the root column along a chain of foreign key relationships.""" - - meta = MetaData() - a = Table('a', meta, Column('id', Integer, primary_key=True)) - b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True)) - c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True)) - d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True)) - - assert c.c.id.references(b.c.id) - assert not d.c.id.references(a.c.id) - - assert list(a.join(b).primary_key) == [a.c.id] - assert list(b.join(c).primary_key) == [b.c.id] - assert list(a.join(b).join(c).primary_key) == [a.c.id] - assert list(b.join(c).join(d).primary_key) == [b.c.id] - assert list(d.join(c).join(b).primary_key) == [b.c.id] - assert list(a.join(b).join(c).join(d).primary_key) == [a.c.id] - - def test_join_pk_collapse_explicit(self): - """test that redundant columns in a join get 'collapsed' into a minimal primary key, - which is the root column along a chain of explicit join conditions.""" - - meta = MetaData() - a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) - b = Table('b', meta, Column('id', Integer, 
ForeignKey('a.id'), primary_key=True), Column('x', Integer)) - c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True), Column('x', Integer)) - d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True), Column('x', Integer)) - - print list(a.join(b, a.c.x==b.c.id).primary_key) - assert list(a.join(b, a.c.x==b.c.id).primary_key) == [b.c.id] - assert list(b.join(c, b.c.x==c.c.id).primary_key) == [b.c.id] - assert list(a.join(b).join(c, c.c.id==b.c.x).primary_key) == [a.c.id] - assert list(b.join(c, c.c.x==b.c.id).join(d).primary_key) == [c.c.id] - assert list(b.join(c, c.c.id==b.c.x).join(d).primary_key) == [b.c.id] - assert list(d.join(b, d.c.id==b.c.id).join(c, b.c.id==c.c.x).primary_key) == [c.c.id] - assert list(a.join(b).join(c, c.c.id==b.c.x).join(d).primary_key) == [a.c.id] - - assert list(a.join(b, and_(a.c.id==b.c.id, a.c.x==b.c.id)).primary_key) == [a.c.id] - - def test_init_doesnt_blowitaway(self): - meta = MetaData() - a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) - b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer)) - - j = a.join(b) - assert list(j.primary_key) == [a.c.id] - - j.foreign_keys - assert list(j.primary_key) == [a.c.id] - - -if __name__ == "__main__": - testbase.main() - +"""tests that various From objects properly export their columns, as well as +useable primary keys and foreign keys. Full relational algebra depends on +every selectable unit behaving nicely with others..""" + +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from testlib import * +from sqlalchemy.sql import util as sql_util +from sqlalchemy import exceptions + +metadata = MetaData() +table = Table('table1', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(20)), + Column('col3', Integer), + Column('colx', Integer), + +) + +table2 = Table('table2', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', Integer, ForeignKey('table1.col1')), + Column('col3', String(20)), + Column('coly', Integer), +) + +class SelectableTest(TestBase, AssertsExecutionResults): + def testdistance(self): + # same column three times + s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')]) + + # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far + #assert s.corresponding_column(table.c.col1) is s.c.col1 + assert s.corresponding_column(s.c.col1) is s.c.col1 + assert s.corresponding_column(s.c.c1) is s.c.c1 + + def testjoinagainstself(self): + jj = select([table.c.col1.label('bar_col1')]) + jjj = join(table, jj, table.c.col1==jj.c.bar_col1) + + # test column directly agaisnt itself + assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 + + assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1 + + # test alias of the join, targets the column with the least + # "distance" between the requested column and the returned column + # (i.e. 
there is less indirection between j2.c.table1_col1 and table.c.col1, than + # there is from j2.c.bar_col1 to table.c.col1) + j2 = jjj.alias('foo') + assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1 + + def testselectontable(self): + sel = select([table, table2], use_labels=True) + assert sel.corresponding_column(table.c.col1) is sel.c.table1_col1 + assert sel.corresponding_column(table.c.col1, require_embedded=True) is sel.c.table1_col1 + assert table.corresponding_column(sel.c.table1_col1) is table.c.col1 + assert table.corresponding_column(sel.c.table1_col1, require_embedded=True) is None + + def testjoinagainstjoin(self): + j = outerjoin(table, table2, table.c.col1==table2.c.col2) + jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') + jjj = join(table, jj, table.c.col1==jj.c.bar_col1) + assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 + + j2 = jjj.alias('foo') + print j2.corresponding_column(jjj.c.table1_col1) + assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1 + + assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1 + + def testtablealias(self): + a = table.alias('a') + + j = join(a, table2) + + criterion = a.c.col1 == table2.c.col2 + self.assert_(criterion.compare(j.onclause)) + + def testunion(self): + # tests that we can correspond a column in a Select statement with a certain Table, against + # a column in a Union where one of its underlying Selects matches to that same Table + u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) + ) + s1 = table.select(use_labels=True) + s2 = table2.select(use_labels=True) + print ["%d %s" % (id(c),c.key) for c in u.c] + c = u.corresponding_column(s1.c.table1_col2) + print "%d %s" % (id(c), c.key) + print id(u.corresponding_column(s1.c.table1_col2).table) + print id(u.c.col2.table) + assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 + assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 + + def test_singular_union(self): + u = union(select([table.c.col1, table.c.col2, table.c.col3]), select([table.c.col1, table.c.col2, table.c.col3])) + assert u.oid_column is not None + + u = union(select([table.c.col1, table.c.col2, table.c.col3])) + assert u.oid_column + assert u.c.col1 + assert u.c.col2 + assert u.c.col3 + + def testaliasunion(self): + # same as testunion, except its an alias of the union + u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) + ).alias('analias') + s1 = table.select(use_labels=True) + s2 = table2.select(use_labels=True) + assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 + assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 + assert u.corresponding_column(s2.c.table2_coly) is u.c.coly + assert s2.corresponding_column(u.c.coly) is s2.c.table2_coly + + def testselectunion(self): + # like testaliasunion, but off a Select off the union. 
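testunion, test_singular_union and testaliasunion above, and testselectunion below, all pin down the same correspondence rule: each column of a UNION (or of an alias or SELECT wrapped around it) acts as a proxy for the matching column of every underlying SELECT, so corresponding_column() resolves through either side. A minimal sketch of that expectation with throwaway tables (names here are illustrative, not this module's fixtures):

    from sqlalchemy import MetaData, Table, Column, Integer, String, select

    meta = MetaData()
    t1 = Table('t1', meta,
        Column('id', Integer, primary_key=True),
        Column('name', String(30)))
    t2 = Table('t2', meta,
        Column('id', Integer, primary_key=True),
        Column('name', String(30)))

    u = select([t1.c.id, t1.c.name]).union(
        select([t2.c.id, t2.c.name])).alias('u')

    # the union's "name" column proxies t1.name via its first SELECT and
    # t2.name via its second, so both lookups land on the same column
    assert u.corresponding_column(t1.c.name) is u.c.name
    assert u.corresponding_column(t2.c.name) is u.c.name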
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) + ).alias('analias') + s = select([u]) + s1 = table.select(use_labels=True) + s2 = table2.select(use_labels=True) + assert s.corresponding_column(s1.c.table1_col2) is s.c.col2 + assert s.corresponding_column(s2.c.table2_col2) is s.c.col2 + + def testunionagainstjoin(self): + # same as testunion, except its an alias of the union + u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) + ).alias('analias') + j1 = table.join(table2) + assert u.corresponding_column(j1.c.table1_colx) is u.c.colx + assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx + + def testjoin(self): + a = join(table, table2) + print str(a.select(use_labels=True)) + b = table2.alias('b') + j = join(a, b) + print str(j) + criterion = a.c.table1_col1 == b.c.col2 + self.assert_(criterion.compare(j.onclause)) + + def testselectalias(self): + a = table.select().alias('a') + print str(a.select()) + j = join(a, table2) + + criterion = a.c.col1 == table2.c.col2 + print criterion + print j.onclause + self.assert_(criterion.compare(j.onclause)) + + def testselectlabels(self): + a = table.select(use_labels=True) + print str(a.select()) + j = join(a, table2) + + criterion = a.c.table1_col1 == table2.c.col2 + print + print str(j) + self.assert_(criterion.compare(j.onclause)) + + def testcolumnlabels(self): + a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')]) + print str(a) + print [c for c in a.columns] + print str(a.select()) + j = join(a, table2) + criterion = a.c.acol1 == table2.c.col2 + print str(j) + self.assert_(criterion.compare(j.onclause)) + + def test_labeled_select_correspoinding(self): + l1 = select([func.max(table.c.col1)]).label('foo') + + s = select([l1]) + assert s.corresponding_column(l1).name == s.c.foo + + s = select([table.c.col1, l1]) + assert s.corresponding_column(l1).name == s.c.foo + + def testselectaliaslabels(self): + a = table2.select(use_labels=True).alias('a') + print str(a.select()) + j = join(a, table) + + criterion = table.c.col1 == a.c.table2_col2 + print str(criterion) + print str(j.onclause) + self.assert_(criterion.compare(j.onclause)) + + def testtablejoinedtoselectoftable(self): + metadata = MetaData() + a = Table('a', metadata, + Column('id', Integer, primary_key=True)) + b = Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + + j1 = a.outerjoin(b) + j2 = select([a.c.id.label('aid')]).alias('bar') + + j3 = a.join(j2, j2.c.aid==a.c.id) + + j4 = select([j3]).alias('foo') + print j4 + print j4.corresponding_column(j2.c.aid) + print j4.c.aid + assert j4.corresponding_column(j2.c.aid) is j4.c.aid + assert j4.corresponding_column(a.c.id) is j4.c.id + + @testing.emits_warning('.*replaced by another column with the same key') + def test_oid(self): + # the oid column of a selectable currently proxies all + # oid columns found within. 
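Column proxying is transitive: a SELECT wrapped around another SELECT proxies the inner statement's columns, which in turn proxy the base table's, and corresponding_column() walks the whole chain. The oid column follows the same rule, which is what the assertions below check; in this 0.4-era code every table and selectable carries an implicit oid_column (a stand-in for Postgres-style row OIDs), and a composite select's oid column proxies the oid columns of everything inside it. A small sketch of the transitive lookup (throwaway table, not the module fixtures):

    from sqlalchemy import MetaData, Table, Column, Integer, select

    meta = MetaData()
    t = Table('t', meta, Column('id', Integer, primary_key=True))

    inner = t.select()
    outer = select([inner])

    # each level proxies the one below it, so the lookup threads all the
    # way down to the base table column
    assert inner.corresponding_column(t.c.id) is inner.c.id
    assert outer.corresponding_column(inner.c.id) is outer.c.id
    assert outer.corresponding_column(t.c.id) is outer.c.id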
+ s = table.select() + s2 = table2.select() + s3 = select([s, s2]) + assert s3.corresponding_column(table.oid_column) is s3.oid_column + assert s3.corresponding_column(table2.oid_column) is s3.oid_column + assert s3.corresponding_column(s.oid_column) is s3.oid_column + assert s3.corresponding_column(s2.oid_column) is s3.oid_column + + u = s.union(s2) + assert u.corresponding_column(table.oid_column) is u.oid_column + assert u.corresponding_column(table2.oid_column) is u.oid_column + assert u.corresponding_column(s.oid_column) is u.oid_column + assert u.corresponding_column(s2.oid_column) is u.oid_column + + def test_two_metadata_join_raises(self): + m = MetaData() + m2 = MetaData() + + t1 = Table('t1', m, Column('id', Integer), Column('id2', Integer)) + t2 = Table('t2', m, Column('id', Integer, ForeignKey('t1.id'))) + t3 = Table('t3', m2, Column('id', Integer, ForeignKey('t1.id2'))) + + s = select([t2, t3], use_labels=True) + + self.assertRaises(exceptions.NoReferencedTableError, s.join, t1) + +class PrimaryKeyTest(TestBase, AssertsExecutionResults): + def test_join_pk_collapse_implicit(self): + """test that redundant columns in a join get 'collapsed' into a minimal primary key, + which is the root column along a chain of foreign key relationships.""" + + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True)) + c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True)) + d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True)) + + assert c.c.id.references(b.c.id) + assert not d.c.id.references(a.c.id) + + assert list(a.join(b).primary_key) == [a.c.id] + assert list(b.join(c).primary_key) == [b.c.id] + assert list(a.join(b).join(c).primary_key) == [a.c.id] + assert list(b.join(c).join(d).primary_key) == [b.c.id] + assert list(d.join(c).join(b).primary_key) == [b.c.id] + assert list(a.join(b).join(c).join(d).primary_key) == [a.c.id] + + def test_join_pk_collapse_explicit(self): + """test that redundant columns in a join get 'collapsed' into a minimal primary key, + which is the root column along a chain of explicit join conditions.""" + + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer)) + c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True), Column('x', Integer)) + d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True), Column('x', Integer)) + + print list(a.join(b, a.c.x==b.c.id).primary_key) + assert list(a.join(b, a.c.x==b.c.id).primary_key) == [b.c.id] + assert list(b.join(c, b.c.x==c.c.id).primary_key) == [b.c.id] + assert list(a.join(b).join(c, c.c.id==b.c.x).primary_key) == [a.c.id] + assert list(b.join(c, c.c.x==b.c.id).join(d).primary_key) == [c.c.id] + assert list(b.join(c, c.c.id==b.c.x).join(d).primary_key) == [b.c.id] + assert list(d.join(b, d.c.id==b.c.id).join(c, b.c.id==c.c.x).primary_key) == [c.c.id] + assert list(a.join(b).join(c, c.c.id==b.c.x).join(d).primary_key) == [a.c.id] + + assert list(a.join(b, and_(a.c.id==b.c.id, a.c.x==b.c.id)).primary_key) == [a.c.id] + + def test_init_doesnt_blowitaway(self): + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', 
Integer)) + + j = a.join(b) + assert list(j.primary_key) == [a.c.id] + + j.foreign_keys + assert list(j.primary_key) == [a.c.id] + + def test_non_column_clause(self): + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer, primary_key=True)) + + j = a.join(b, and_(a.c.id==b.c.id, b.c.x==5)) + assert str(j) == "a JOIN b ON a.id = b.id AND b.x = :x_1", str(j) + assert list(j.primary_key) == [a.c.id, b.c.x] + + def test_onclause_direction(self): + metadata = MetaData() + + employee = Table( 'Employee', metadata, + Column('name', String(100)), + Column('id', Integer, primary_key= True), + ) + + engineer = Table( 'Engineer', metadata, + Column('id', Integer, ForeignKey( 'Employee.id', ), primary_key=True), + ) + + self.assertEquals( + set(employee.join(engineer, employee.c.id==engineer.c.id).primary_key), + set([employee.c.id]) + ) + + self.assertEquals( + set(employee.join(engineer, engineer.c.id==employee.c.id).primary_key), + set([employee.c.id]) + ) + + +class ReduceTest(TestBase, AssertsExecutionResults): + def test_reduce(self): + meta = MetaData() + t1 = Table('t1', meta, + Column('t1id', Integer, primary_key=True), + Column('t1data', String(30))) + t2 = Table('t2', meta, + Column('t2id', Integer, ForeignKey('t1.t1id'), primary_key=True), + Column('t2data', String(30))) + t3 = Table('t3', meta, + Column('t3id', Integer, ForeignKey('t2.t2id'), primary_key=True), + Column('t3data', String(30))) + + + self.assertEquals( + set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])), + set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data]) + ) + + def test_reduce_selectable(self): + metadata = MetaData() + + engineers = Table('engineers', metadata, + Column('engineer_id', Integer, primary_key=True), + Column('engineer_name', String(50)), + ) + + managers = Table('managers', metadata, + Column('manager_id', Integer, primary_key=True), + Column('manager_name', String(50)) + ) + + s = select([engineers, managers]).where(engineers.c.engineer_name==managers.c.manager_name) + + self.assertEquals(set(sql_util.reduce_columns(list(s.c), s)), + set([s.c.engineer_id, s.c.engineer_name, s.c.manager_id]) + ) + + def test_reduce_aliased_join(self): + metadata = MetaData() + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('engineer_name', String(50)), + Column('primary_language', String(50)), + ) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('manager_name', String(50)) + ) + + pjoin = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin') + self.assertEquals( + set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])), + set([pjoin.c.people_person_id]) + ) + + def test_reduce_aliased_union(self): + metadata = MetaData() + item_table = Table( + 'item', metadata, + Column('id', Integer, ForeignKey('base_item.id'), primary_key=True), + Column('dummy', Integer, default=0)) + + base_item_table = Table( + 
'base_item', metadata, + Column('id', Integer, primary_key=True), + Column('child_name', String(255), default=None)) + + from sqlalchemy.orm.util import polymorphic_union + + item_join = polymorphic_union( { + 'BaseItem':base_item_table.select(base_item_table.c.child_name=='BaseItem'), + 'Item':base_item_table.join(item_table), + }, None, 'item_join') + + self.assertEquals( + set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])), + set([item_join.c.id, item_join.c.dummy, item_join.c.child_name]) + ) + + def test_reduce_aliased_union_2(self): + metadata = MetaData() + + page_table = Table('page', metadata, + Column('id', Integer, primary_key=True), + ) + magazine_page_table = Table('magazine_page', metadata, + Column('page_id', Integer, ForeignKey('page.id'), primary_key=True), + ) + classified_page_table = Table('classified_page', metadata, + Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True), + ) + + from sqlalchemy.orm.util import polymorphic_union + pjoin = polymorphic_union( + { + 'm': page_table.join(magazine_page_table), + 'c': page_table.join(magazine_page_table).join(classified_page_table), + }, None, 'page_join') + + self.assertEquals( + set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])), + set([pjoin.c.id]) + ) + + +class DerivedTest(TestBase, AssertsExecutionResults): + def test_table(self): + meta = MetaData() + t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + + assert t1.is_derived_from(t1) + assert not t2.is_derived_from(t1) + + def test_alias(self): + meta = MetaData() + t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + + assert t1.alias().is_derived_from(t1) + assert not t2.alias().is_derived_from(t1) + assert not t1.is_derived_from(t1.alias()) + assert not t1.is_derived_from(t2.alias()) + + def test_select(self): + meta = MetaData() + t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + + assert t1.select().is_derived_from(t1) + assert not t2.select().is_derived_from(t1) + + assert select([t1, t2]).is_derived_from(t1) + + assert t1.select().alias('foo').is_derived_from(t1) + assert select([t1, t2]).alias('foo').is_derived_from(t1) + assert not t2.select().alias('foo').is_derived_from(t1) + +if __name__ == "__main__": + testenv.main() diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 6590330164..09a3702ee7 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -1,55 +1,36 @@ -import testbase -import pickleable -import datetime, os +import testenv; testenv.configure_for_tests() +import datetime, os, pickleable, re from sqlalchemy import * +from sqlalchemy import exceptions, types, util +from sqlalchemy.sql import operators import sqlalchemy.engine.url as url -from sqlalchemy.databases import mssql, oracle, mysql +from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird from testlib import * -class MyType(types.TypeEngine): - def get_col_spec(self): - return "VARCHAR(100)" - def convert_bind_param(self, value, engine): - return "BIND_IN"+ value - def convert_result_value(self, value, engine): - return value + "BIND_OUT" - def 
adapt(self, typeobj): - return typeobj() - -class MyDecoratedType(types.TypeDecorator): - impl = String - def convert_bind_param(self, value, dialect): - return "BIND_IN"+ super(MyDecoratedType, self).convert_bind_param(value, dialect) - def convert_result_value(self, value, dialect): - return super(MyDecoratedType, self).convert_result_value(value, dialect) + "BIND_OUT" - def copy(self): - return MyDecoratedType() - -class MyUnicodeType(types.TypeDecorator): - impl = Unicode - def convert_bind_param(self, value, dialect): - return "UNI_BIND_IN"+ super(MyUnicodeType, self).convert_bind_param(value, dialect) - def convert_result_value(self, value, dialect): - return super(MyUnicodeType, self).convert_result_value(value, dialect) + "UNI_BIND_OUT" - def copy(self): - return MyUnicodeType(self.impl.length) - -class AdaptTest(PersistTest): + +class AdaptTest(TestBase): def testadapt(self): e1 = url.URL('postgres').get_dialect()() e2 = url.URL('mysql').get_dialect()() e3 = url.URL('sqlite').get_dialect()() - + e4 = url.URL('firebird').get_dialect()() + type = String(40) - + t1 = type.dialect_impl(e1) t2 = type.dialect_impl(e2) t3 = type.dialect_impl(e3) - assert t1 != t2 - assert t2 != t3 - assert t3 != t1 - + t4 = type.dialect_impl(e4) + + impls = [t1, t2, t3, t4] + for i,ta in enumerate(impls): + for j,tb in enumerate(impls): + if i == j: + assert ta == tb # call me paranoid... :) + else: + assert ta != tb + def testmsnvarchar(self): dialect = mssql.MSSQLDialect() # run the test twice to insure the caching step works too @@ -61,6 +42,11 @@ class AdaptTest(PersistTest): def testoracletext(self): dialect = oracle.OracleDialect() + class MyDecoratedType(types.TypeDecorator): + impl = String + def copy(self): + return MyDecoratedType() + col = Column('', MyDecoratedType) dialect_type = col.type.dialect_impl(dialect) assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl) @@ -81,91 +67,257 @@ class AdaptTest(PersistTest): t2 = mysql.MSVarBinary() assert isinstance(dialect.type_descriptor(t1), mysql.MSVarBinary) assert isinstance(dialect.type_descriptor(t2), mysql.MSVarBinary) - - -class OverrideTest(PersistTest): - """tests user-defined types, including a full type as well as a TypeDecorator""" + + def teststringadapt(self): + """test that String with no size becomes TEXT, *all* others stay as varchar/String""" + + oracle_dialect = oracle.OracleDialect() + mysql_dialect = mysql.MySQLDialect() + postgres_dialect = postgres.PGDialect() + firebird_dialect = firebird.FBDialect() + + for dialect, start, test in [ + (oracle_dialect, String(), oracle.OracleText), + (oracle_dialect, VARCHAR(), oracle.OracleString), + (oracle_dialect, String(50), oracle.OracleString), + (oracle_dialect, Unicode(), oracle.OracleText), + (oracle_dialect, UnicodeText(), oracle.OracleText), + (oracle_dialect, NCHAR(), oracle.OracleString), + (oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw), + (mysql_dialect, String(), mysql.MSText), + (mysql_dialect, VARCHAR(), mysql.MSString), + (mysql_dialect, String(50), mysql.MSString), + (mysql_dialect, Unicode(), mysql.MSText), + (mysql_dialect, UnicodeText(), mysql.MSText), + (mysql_dialect, NCHAR(), mysql.MSNChar), + (postgres_dialect, String(), postgres.PGText), + (postgres_dialect, VARCHAR(), postgres.PGString), + (postgres_dialect, String(50), postgres.PGString), + (postgres_dialect, Unicode(), postgres.PGText), + (postgres_dialect, UnicodeText(), postgres.PGText), + (postgres_dialect, NCHAR(), postgres.PGString), + (firebird_dialect, String(), 
firebird.FBText), + (firebird_dialect, VARCHAR(), firebird.FBString), + (firebird_dialect, String(50), firebird.FBString), + (firebird_dialect, Unicode(), firebird.FBText), + (firebird_dialect, UnicodeText(), firebird.FBText), + (firebird_dialect, NCHAR(), firebird.FBString), + ]: + assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect)) + + + +class UserDefinedTest(TestBase): + """tests user-defined types.""" def testbasic(self): print users.c.goofy4.type - print users.c.goofy4.type.dialect_impl(testbase.db.dialect) - print users.c.goofy4.type.dialect_impl(testbase.db.dialect).get_col_spec() - + print users.c.goofy4.type.dialect_impl(testing.db.dialect) + print users.c.goofy4.type.dialect_impl(testing.db.dialect).get_col_spec() + def testprocessing(self): global users - users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4='jack') - users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4='lala') - users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4='fred') - + users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12) + users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15) + users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9) + l = users.select().execute().fetchall() - print repr(l) - self.assert_(l == [(2, 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT'), (3, 'BIND_INlalaBIND_OUT', 'BIND_INlalaBIND_OUT', 'BIND_INlalaBIND_OUT', u'UNI_BIND_INlalaUNI_BIND_OUT'), (4, 'BIND_INfredBIND_OUT', 'BIND_INfredBIND_OUT', 'BIND_INfredBIND_OUT', u'UNI_BIND_INfredUNI_BIND_OUT')]) + for assertstr, assertint, assertint2, row in zip( + ["BIND_INjackBIND_OUT", "BIND_INlalaBIND_OUT", "BIND_INfredBIND_OUT"], + [1200, 1500, 900], + [1800, 2250, 1350], + l + + ): + for col in row[1:8]: + self.assertEquals(col, assertstr) + self.assertEquals(row[8], assertint) + self.assertEquals(row[9], assertint2) + for col in (row[4], row[5], row[7]): + assert isinstance(col, unicode) def setUpAll(self): - global users - users = Table('type_users', MetaData(testbase.db), + global users, metadata + + class MyType(types.TypeEngine): + def get_col_spec(self): + return "VARCHAR(100)" + def bind_processor(self, dialect): + def process(value): + return "BIND_IN"+ value + return process + def result_processor(self, dialect): + def process(value): + return value + "BIND_OUT" + return process + def adapt(self, typeobj): + return typeobj() + + class MyDecoratedType(types.TypeDecorator): + impl = String + def bind_processor(self, dialect): + impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value) + def process(value): + return "BIND_IN"+ impl_processor(value) + return process + def result_processor(self, dialect): + impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value) + def process(value): + return impl_processor(value) + "BIND_OUT" + return process + def copy(self): + return MyDecoratedType() + + class MyNewUnicodeType(types.TypeDecorator): + impl = Unicode + + def process_bind_param(self, value, dialect): + return 
"BIND_IN" + value + + def process_result_value(self, value, dialect): + return value + "BIND_OUT" + + def copy(self): + return MyNewUnicodeType(self.impl.length) + + class MyNewIntType(types.TypeDecorator): + impl = Integer + + def process_bind_param(self, value, dialect): + return value * 10 + + def process_result_value(self, value, dialect): + return value * 10 + + def copy(self): + return MyNewIntType() + + class MyNewIntSubClass(MyNewIntType): + def process_result_value(self, value, dialect): + return value * 15 + + def copy(self): + return MyNewIntSubClass() + + class MyUnicodeType(types.TypeDecorator): + impl = Unicode + + def bind_processor(self, dialect): + impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value) + + def process(value): + return "BIND_IN"+ impl_processor(value) + return process + + def result_processor(self, dialect): + impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value) + def process(value): + return impl_processor(value) + "BIND_OUT" + return process + + def copy(self): + return MyUnicodeType(self.impl.length) + + class LegacyType(types.TypeEngine): + def get_col_spec(self): + return "VARCHAR(100)" + def convert_bind_param(self, value, dialect): + return "BIND_IN"+ value + def convert_result_value(self, value, dialect): + return value + "BIND_OUT" + def adapt(self, typeobj): + return typeobj() + + class LegacyUnicodeType(types.TypeDecorator): + impl = Unicode + + def convert_bind_param(self, value, dialect): + return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect) + + def convert_result_value(self, value, dialect): + return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT" + + def copy(self): + return LegacyUnicodeType(self.impl.length) + + metadata = MetaData(testing.db) + users = Table('type_users', metadata, Column('user_id', Integer, primary_key = True), # totall custom type Column('goofy', MyType, nullable = False), - + # decorated type with an argument, so its a String Column('goofy2', MyDecoratedType(50), nullable = False), - + # decorated type without an argument, it will adapt_args to TEXT Column('goofy3', MyDecoratedType, nullable = False), Column('goofy4', MyUnicodeType, nullable = False), + Column('goofy5', LegacyUnicodeType, nullable = False), + Column('goofy6', LegacyType, nullable = False), + Column('goofy7', MyNewUnicodeType, nullable = False), + Column('goofy8', MyNewIntType, nullable = False), + Column('goofy9', MyNewIntSubClass, nullable = False), ) - - users.create() - def tearDownAll(self): - global users - users.drop() + metadata.create_all() + + def tearDownAll(self): + metadata.drop_all() -class ColumnsTest(AssertMixin): +class ColumnsTest(TestBase, AssertsExecutionResults): def testcolumns(self): expectedResults = { 'int_column': 'int_column INTEGER', 'smallint_column': 'smallint_column SMALLINT', 'varchar_column': 'varchar_column VARCHAR(20)', 'numeric_column': 'numeric_column NUMERIC(12, 3)', - 'float_column': 'float_column NUMERIC(25, 2)' + 'float_column': 'float_column FLOAT(25)', } - db = testbase.db - if not db.name=='sqlite' and not db.name=='oracle': - expectedResults['float_column'] = 'float_column FLOAT(25)' - + db = testing.db + if testing.against('sqlite', 'oracle'): + expectedResults['float_column'] = 'float_column NUMERIC(25, 2)' + + if testing.against('maxdb'): + expectedResults['numeric_column'] = ( + expectedResults['numeric_column'].replace('NUMERIC', 'FIXED')) + print 
db.engine.__module__ testTable = Table('testColumns', MetaData(db), Column('int_column', Integer), - Column('smallint_column', Smallinteger), + Column('smallint_column', SmallInteger), Column('varchar_column', String(20)), Column('numeric_column', Numeric(12,3)), Column('float_column', Float(25)), ) for aCol in testTable.c: - self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None, None).get_column_specification(aCol)) - -class UnicodeTest(AssertMixin): + self.assertEquals( + expectedResults[aCol.name], + db.dialect.schemagenerator(db.dialect, db, None, None).\ + get_column_specification(aCol)) + +class UnicodeTest(TestBase, AssertsExecutionResults): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" def setUpAll(self): global unicode_table - metadata = MetaData(testbase.db) - unicode_table = Table('unicode_table', metadata, + metadata = MetaData(testing.db) + unicode_table = Table('unicode_table', metadata, Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True), Column('unicode_varchar', Unicode(250)), - Column('unicode_text', Unicode), + Column('unicode_text', UnicodeText), Column('plain_varchar', String(250)) ) unicode_table.create() def tearDownAll(self): unicode_table.drop() + def tearDown(self): + unicode_table.delete().execute() + def testbasic(self): assert unicode_table.c.unicode_varchar.type.length == 250 rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' @@ -174,57 +326,108 @@ class UnicodeTest(AssertMixin): unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - print repr(x['unicode_varchar']) - print repr(x['unicode_text']) - print repr(x['plain_varchar']) + print 0, repr(unicodedata) + print 1, repr(x['unicode_varchar']) + print 2, repr(x['unicode_text']) + print 3, repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) if isinstance(x['plain_varchar'], unicode): # SQLLite and MSSQL return non-unicode data as unicode - self.assert_(testbase.db.name in ('sqlite', 'mssql')) + self.assert_(testing.against('sqlite', 'mssql')) self.assert_(x['plain_varchar'] == unicodedata) - print "it's %s!" % testbase.db.name + print "it's %s!" % testing.db.name else: self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) + def testassert(self): + try: + unicode_table.insert().execute(unicode_varchar='not unicode') + assert False + except exceptions.SAWarning, e: + assert str(e) == "Unicode type received non-unicode bind param value 'not unicode'", str(e) + + unicode_engine = engines.utf8_engine(options={'convert_unicode':True, + 'assert_unicode':True}) + try: + try: + unicode_engine.execute(unicode_table.insert(), plain_varchar='im not unicode') + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'" + + @testing.emits_warning('.*non-unicode bind') + def warns(): + # test that data still goes in if warning is emitted.... 
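The nested warns() helper below pins down the default path: when the "non-unicode bind param" warning is merely emitted rather than raised, the bytestring is still bound and the row is stored; the assert_unicode engine configured just above turns the same condition into an InvalidRequestError instead. Where coercion is preferred over a warning, a thin TypeDecorator can decode bytestrings before the underlying Unicode type sees them; a sketch using the process_bind_param hook introduced earlier in this patch (the CoercingUnicode name and the utf-8 assumption are illustrative only):

    from sqlalchemy import types, Unicode

    class CoercingUnicode(types.TypeDecorator):
        """decode plain bytestrings to unicode before binding."""
        impl = Unicode

        def process_bind_param(self, value, dialect):
            if isinstance(value, str):
                value = value.decode('utf-8')  # assumed encoding
            return value

        def process_result_value(self, value, dialect):
            return value

        def copy(self):
            return CoercingUnicode(self.impl.length)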
+ unicode_table.insert().execute(unicode_varchar='not unicode') + assert (select([unicode_table.c.unicode_varchar]).execute().fetchall() == [('not unicode', )]) + warns() + + finally: + unicode_engine.dispose() + + @testing.fails_on('oracle') + def testblanks(self): + unicode_table.insert().execute(unicode_varchar=u'') + assert select([unicode_table.c.unicode_varchar]).scalar() == u'' + def testengineparam(self): """tests engine-wide unicode conversion""" - prev_unicode = testbase.db.engine.dialect.convert_unicode + prev_unicode = testing.db.engine.dialect.convert_unicode + prev_assert = testing.db.engine.dialect.assert_unicode try: - testbase.db.engine.dialect.convert_unicode = True + testing.db.engine.dialect.convert_unicode = True + testing.db.engine.dialect.assert_unicode = False rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') unicode_table.insert().execute(unicode_varchar=unicodedata, unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - print repr(x['unicode_varchar']) - print repr(x['unicode_text']) - print repr(x['plain_varchar']) + print 0, repr(unicodedata) + print 1, repr(x['unicode_varchar']) + print 2, repr(x['unicode_text']) + print 3, repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) finally: - testbase.db.engine.dialect.convert_unicode = prev_unicode + testing.db.engine.dialect.convert_unicode = prev_unicode + testing.db.engine.dialect.convert_unicode = prev_assert @testing.unsupported('oracle') def testlength(self): """checks the database correctly understands the length of a unicode string""" teststr = u'aaa\x1234' - self.assert_(testbase.db.func.length(teststr).scalar() == len(teststr)) - -class BinaryTest(AssertMixin): + self.assert_(testing.db.func.length(teststr).scalar() == len(teststr)) + +class BinaryTest(TestBase, AssertsExecutionResults): def setUpAll(self): - global binary_table - binary_table = Table('binary_table', MetaData(testbase.db), + global binary_table, MyPickleType + + class MyPickleType(types.TypeDecorator): + impl = PickleType + + def process_bind_param(self, value, dialect): + if value: + value.stuff = 'this is modified stuff' + return value + + def process_result_value(self, value, dialect): + if value: + value.stuff = 'this is the right stuff' + return value + + binary_table = Table('binary_table', MetaData(testing.db), Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True), Column('data', Binary), Column('data_slice', Binary(100)), Column('misc', String(30)), # construct PickleType with non-native pickle module, since cPickle uses relative module # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative - # to the 'types' module - Column('pickled', PickleType) + # to the 'types' module + Column('pickled', PickleType), + Column('mypickle', MyPickleType) ) binary_table.create() @@ -237,87 +440,176 @@ class BinaryTest(AssertMixin): def testbinary(self): testobj1 = pickleable.Foo('im foo 1') testobj2 = pickleable.Foo('im foo 2') + testobj3 = 
pickleable.Foo('im foo 3') stream1 =self.load_stream('binary_data_one.dat') stream2 =self.load_stream('binary_data_two.dat') - binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1) + binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3) binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2) binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None) - + for stmt in ( binary_table.select(order_by=binary_table.c.primary_id), - text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, bind=testbase.db) + text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db) ): l = stmt.execute().fetchall() print type(stream1), type(l[0]['data']), type(l[0]['data_slice']) print len(stream1), len(l[0]['data']), len(l[0]['data_slice']) - self.assert_(list(stream1) == list(l[0]['data'])) - self.assert_(list(stream1[0:100]) == list(l[0]['data_slice'])) - self.assert_(list(stream2) == list(l[1]['data'])) - self.assert_(testobj1 == l[0]['pickled']) - self.assert_(testobj2 == l[1]['pickled']) + self.assertEquals(list(stream1), list(l[0]['data'])) + self.assertEquals(list(stream1[0:100]), list(l[0]['data_slice'])) + self.assertEquals(list(stream2), list(l[1]['data'])) + self.assertEquals(testobj1, l[0]['pickled']) + self.assertEquals(testobj2, l[1]['pickled']) + self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata) + self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff') def load_stream(self, name, len=12579): - f = os.path.join(os.path.dirname(testbase.__file__), name) + f = os.path.join(os.path.dirname(testenv.__file__), name) # put a number less than the typical MySQL default BLOB size return file(f).read(len) - - -class DateTest(AssertMixin): + +class ExpressionTest(TestBase, AssertsExecutionResults): + def setUpAll(self): + global test_table, meta + + class MyCustomType(types.TypeEngine): + def get_col_spec(self): + return "INT" + def bind_processor(self, dialect): + def process(value): + return value * 10 + return process + def result_processor(self, dialect): + def process(value): + return value / 10 + return process + def adapt_operator(self, op): + return {operators.add:operators.sub, operators.sub:operators.add}.get(op, op) + + meta = MetaData(testing.db) + test_table = Table('test', meta, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('atimestamp', Date), + Column('avalue', MyCustomType)) + + meta.create_all() + + test_table.insert().execute({'id':1, 'data':'somedata', 'atimestamp':datetime.date(2007, 10, 15), 'avalue':25}) + + def tearDownAll(self): + meta.drop_all() + + def test_control(self): + assert testing.db.execute("select avalue from test").scalar() == 250 + + assert test_table.select().execute().fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + + def test_bind_adapt(self): + expr = test_table.c.atimestamp == bindparam("thedate") + assert expr.right.type.__class__ == test_table.c.atimestamp.type.__class__ + + assert testing.db.execute(test_table.select().where(expr), {"thedate":datetime.date(2007, 10, 15)}).fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + + expr = 
test_table.c.avalue == bindparam("somevalue") + assert expr.right.type.__class__ == test_table.c.avalue.type.__class__ + assert testing.db.execute(test_table.select().where(expr), {"somevalue":25}).fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + + + def test_operator_adapt(self): + """test type-based overloading of operators""" + + # test string concatenation + expr = test_table.c.data + "somedata" + assert testing.db.execute(select([expr])).scalar() == "somedatasomedata" + + expr = test_table.c.id + 15 + assert testing.db.execute(select([expr])).scalar() == 16 + + # test custom operator conversion + expr = test_table.c.avalue + 40 + assert expr.type.__class__ is test_table.c.avalue.type.__class__ + + # + operator converted to - + # value is calculated as: (250 - (40 * 10)) / 10 == -15 + assert testing.db.execute(select([expr.label('foo')])).scalar() == -15 + + # this one relies upon anonymous labeling to assemble result + # processing rules on the column. + assert testing.db.execute(select([expr])).scalar() == -15 + +class DateTest(TestBase, AssertsExecutionResults): def setUpAll(self): global users_with_date, insert_data - db = testbase.db - if db.engine.name == 'oracle': + db = testing.db + if testing.against('oracle'): import sqlalchemy.databases.oracle as oracle insert_data = [ - [7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.date(2005,11,10), datetime.datetime(2005, 11, 10, 0, 0, 0, 29384)], - [8, 'roy', datetime.datetime(2005, 11, 10, 11, 52, 35), datetime.date(2005,10,10), datetime.datetime(2006, 5, 10, 15, 32, 47, 6754)], - [9, 'foo', datetime.datetime(2006, 11, 10, 11, 52, 35), datetime.date(1970,4,1), datetime.datetime(2004, 9, 18, 4, 0, 52, 1043)], + [7, 'jack', + datetime.datetime(2005, 11, 10, 0, 0), + datetime.date(2005,11,10), + datetime.datetime(2005, 11, 10, 0, 0, 0, 29384)], + [8, 'roy', + datetime.datetime(2005, 11, 10, 11, 52, 35), + datetime.date(2005,10,10), + datetime.datetime(2006, 5, 10, 15, 32, 47, 6754)], + [9, 'foo', + datetime.datetime(2006, 11, 10, 11, 52, 35), + datetime.date(1970,4,1), + datetime.datetime(2004, 9, 18, 4, 0, 52, 1043)], [10, 'colber', None, None, None] ] + fnames = ['user_id', 'user_name', 'user_datetime', + 'user_date', 'user_time'] + + collist = [Column('user_id', INT, primary_key=True), + Column('user_name', VARCHAR(20)), + Column('user_datetime', DateTime), + Column('user_date', Date), + Column('user_time', TIMESTAMP)] + else: + datetime_micro = 54839 + time_micro = 999 - fnames = ['user_id', 'user_name', 'user_datetime', 'user_date', 'user_time'] - - collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime), - Column('user_date', Date), Column('user_time', TIMESTAMP)] - elif db.engine.name == 'mysql': - # these dont really support the TIME type at all - insert_data = [ - [7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.datetime(2005, 11, 10, 0, 0, 0)], - [8, 'roy', datetime.datetime(2005, 11, 10, 11, 52, 35), datetime.datetime(2006, 5, 10, 15, 32, 47)], - [9, 'foo', datetime.datetime(2005, 11, 10, 11, 52, 35), datetime.datetime(2004, 9, 18, 4, 0, 52)], - [10, 'colber', None, None] - ] - - fnames = ['user_id', 'user_name', 'user_datetime', 'user_date', 'user_time'] + # Missing or poor microsecond support: + if testing.against('mssql', 'mysql', 'firebird'): + datetime_micro, time_micro = 0, 0 + # No microseconds for TIME + elif testing.against('maxdb'): + time_micro = 0 - collist = [Column('user_id', INT, primary_key = True), 
Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime), - Column('user_date', DateTime)] - else: insert_data = [ - [7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.date(2005,11,10), datetime.time(12,20,2)], - [8, 'roy', datetime.datetime(2005, 11, 10, 11, 52, 35), datetime.date(2005,10,10), datetime.time(0,0,0)], - [9, 'foo', datetime.datetime(2005, 11, 10, 11, 52, 35, 54839), datetime.date(1970,4,1), datetime.time(23,59,59,999)], - [10, 'colber', None, None, None] + [7, 'jack', + datetime.datetime(2005, 11, 10, 0, 0), + datetime.date(2005, 11, 10), + datetime.time(12, 20, 2)], + [8, 'roy', + datetime.datetime(2005, 11, 10, 11, 52, 35), + datetime.date(2005, 10, 10), + datetime.time(0, 0, 0)], + [9, 'foo', + datetime.datetime(2005, 11, 10, 11, 52, 35, datetime_micro), + datetime.date(1970, 4, 1), + datetime.time(23, 59, 59, time_micro)], + [10, 'colber', None, None, None] ] + fnames = ['user_id', 'user_name', 'user_datetime', + 'user_date', 'user_time'] - if db.engine.name == 'mssql': - # MSSQL Datetime values have only a 3.33 milliseconds precision - insert_data[2] = [9, 'foo', datetime.datetime(2005, 11, 10, 11, 52, 35, 547000), datetime.date(1970,4,1), datetime.time(23,59,59,997000)] - - fnames = ['user_id', 'user_name', 'user_datetime', 'user_date', 'user_time'] + collist = [Column('user_id', INT, primary_key=True), + Column('user_name', VARCHAR(20)), + Column('user_datetime', DateTime(timezone=False)), + Column('user_date', Date), + Column('user_time', Time)] - collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime(timezone=False)), - Column('user_date', Date), Column('user_time', Time)] - users_with_date = Table('query_users_with_date', - MetaData(testbase.db), *collist) + MetaData(testing.db), *collist) users_with_date.create() insert_dicts = [dict(zip(fnames, d)) for d in insert_data] for idict in insert_dicts: - users_with_date.insert().execute(**idict) # insert the data + users_with_date.insert().execute(**idict) def tearDownAll(self): users_with_date.drop() @@ -326,20 +618,25 @@ class DateTest(AssertMixin): global insert_data l = map(list, users_with_date.select().execute().fetchall()) - self.assert_(l == insert_data, 'DateTest mismatch: got:%s expected:%s' % (l, insert_data)) + self.assert_(l == insert_data, + 'DateTest mismatch: got:%s expected:%s' % (l, insert_data)) + def testtextdate(self): + x = testing.db.text( + "select user_datetime from query_users_with_date", + typemap={'user_datetime':DateTime}).execute().fetchall() - def testtextdate(self): - x = testbase.db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall() - print repr(x) self.assert_(isinstance(x[0][0], datetime.datetime)) - - #x = db.text("select * from query_users_with_date where user_datetime=:date", bindparams=[bindparam('date', )]).execute(date=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall() - #print repr(x) + + x = testing.db.text( + "select * from query_users_with_date where user_datetime=:somedate", + bindparams=[bindparam('somedate', type_=types.DateTime)]).execute( + somedate=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall() + print repr(x) def testdate2(self): - meta = MetaData(testbase.db) + meta = MetaData(testing.db) t = Table('testdate', meta, Column('id', Integer, Sequence('datetest_id_seq', optional=True), @@ -356,13 +653,56 @@ class DateTest(AssertMixin): self.assert_(x.adate.__class__ == datetime.date) 
self.assert_(x.adatetime.__class__ == datetime.datetime) + t.delete().execute() + + # test mismatched date/datetime + t.insert().execute(adate=d2, adatetime=d2) + self.assertEquals(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)]) + self.assertEquals(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)]) + finally: t.drop(checkfirst=True) -class NumericTest(AssertMixin): +class StringTest(TestBase, AssertsExecutionResults): + def test_nolen_string_deprecated(self): + metadata = MetaData(testing.db) + foo =Table('foo', metadata, + Column('one', String)) + + # no warning + select([func.count("*")], bind=testing.db).execute() + + try: + # warning during CREATE + foo.create() + assert False + except exceptions.SADeprecationWarning, e: + assert "Using String type with no length" in str(e) + assert re.search(r'\bone\b', str(e)) + + bar = Table('bar', metadata, Column('one', String(40))) + + try: + # no warning + bar.create() + + # no warning for non-lengthed string + select([func.count("*")], from_obj=bar).execute() + finally: + bar.drop() + +def _missing_decimal(): + """Python implementation supports decimals""" + try: + import decimal + return False + except ImportError: + return True + +class NumericTest(TestBase, AssertsExecutionResults): def setUpAll(self): global numeric_table, metadata - metadata = MetaData(testbase.db) + metadata = MetaData(testing.db) numeric_table = Table('numeric_table', metadata, Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True), Column('numericcol', Numeric(asdecimal=False)), @@ -371,17 +711,22 @@ class NumericTest(AssertMixin): Column('fcasdec', Float(asdecimal=True)) ) metadata.create_all() - + def tearDownAll(self): metadata.drop_all() - + def tearDown(self): numeric_table.delete().execute() - + + @testing.fails_if(_missing_decimal) def test_decimal(self): from decimal import Decimal - numeric_table.insert().execute(numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.78) - numeric_table.insert().execute(numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.78")) + numeric_table.insert().execute( + numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75) + numeric_table.insert().execute( + numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), + ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75")) + l = numeric_table.select().execute().fetchall() print l rounded = [ @@ -389,27 +734,39 @@ class NumericTest(AssertMixin): (l[1][0], l[1][1], round(l[1][2], 5), l[1][3], l[1][4]), ] assert rounded == [ - (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")), - (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")), + (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")), + (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")), ] - - -class IntervalTest(AssertMixin): + + @testing.emits_warning('True Decimal types not available') + def test_decimal_fallback(self): + from sqlalchemy.util import Decimal # could be Decimal or float + + numeric_table.insert().execute(ncasdec=12.4, fcasdec=15.75) + numeric_table.insert().execute(ncasdec=Decimal("12.4"), + fcasdec=Decimal("15.75")) + + for row in numeric_table.select().execute().fetchall(): + assert isinstance(row['ncasdec'], util.decimal_type) + assert isinstance(row['fcasdec'], util.decimal_type) + + +class IntervalTest(TestBase, AssertsExecutionResults): def setUpAll(self): global interval_table, metadata - metadata = MetaData(testbase.db) - interval_table = Table("intervaltable", metadata, + metadata = 
MetaData(testing.db) + interval_table = Table("intervaltable", metadata, Column("id", Integer, Sequence('interval_id_seq', optional=True), primary_key=True), Column("interval", Interval), ) metadata.create_all() - + def tearDown(self): interval_table.delete().execute() - + def tearDownAll(self): metadata.drop_all() - + def test_roundtrip(self): delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17) interval_table.insert().execute(interval=delta) @@ -418,12 +775,12 @@ class IntervalTest(AssertMixin): def test_null(self): interval_table.insert().execute(id=1, inverval=None) assert interval_table.select().execute().fetchone()['interval'] is None - -class BooleanTest(AssertMixin): + +class BooleanTest(TestBase, AssertsExecutionResults): def setUpAll(self): global bool_table - metadata = MetaData(testbase.db) - bool_table = Table('booltest', metadata, + metadata = MetaData(testing.db) + bool_table = Table('booltest', metadata, Column('id', Integer, primary_key=True), Column('value', Boolean)) bool_table.create() @@ -435,14 +792,14 @@ class BooleanTest(AssertMixin): bool_table.insert().execute(id=3, value=True) bool_table.insert().execute(id=4, value=True) bool_table.insert().execute(id=5, value=True) - + res = bool_table.select(bool_table.c.value==True).execute().fetchall() print res assert(res==[(1, True),(3, True),(4, True),(5, True)]) - + res2 = bool_table.select(bool_table.c.value==False).execute().fetchall() print res2 assert(res2==[(2, False)]) if __name__ == "__main__": - testbase.main() + testenv.main() diff --git a/test/sql/unicode.py b/test/sql/unicode.py index f882c2a5f8..9e3ea257e5 100644 --- a/test/sql/unicode.py +++ b/test/sql/unicode.py @@ -1,105 +1,142 @@ # coding: utf-8 """verrrrry basic unicode column name testing""" -import testbase +import testenv; testenv.configure_for_tests() from sqlalchemy import * -from sqlalchemy.orm import mapper, relation, create_session, eagerload from testlib import * +from testlib.engines import utf8_engine +from sqlalchemy.sql import column - -class UnicodeSchemaTest(PersistTest): +class UnicodeSchemaTest(TestBase): + @testing.unsupported('maxdb', 'oracle', 'sybase') def setUpAll(self): - global unicode_bind, metadata, t1, t2 + global unicode_bind, metadata, t1, t2, t3 - unicode_bind = self._unicode_bind() + unicode_bind = utf8_engine() metadata = MetaData(unicode_bind) t1 = Table('unitable1', metadata, Column(u'méil', Integer, primary_key=True), Column(u'\u6e2c\u8a66', Integer), - + test_needs_fk=True, ) t2 = Table(u'Unitéble2', metadata, Column(u'méil', Integer, primary_key=True, key="a"), - Column(u'\u6e2c\u8a66', Integer, ForeignKey(u'unitable1.méil'), key="b"), + Column(u'\u6e2c\u8a66', Integer, ForeignKey(u'unitable1.méil'), + key="b" + ), + test_needs_fk=True, ) + + # Few DBs support Unicode foreign keys + if testing.against('sqlite'): + t3 = Table(u'\u6e2c\u8a66', metadata, + Column(u'\u6e2c\u8a66_id', Integer, primary_key=True, + autoincrement=False), + Column(u'unitable1_\u6e2c\u8a66', Integer, + ForeignKey(u'unitable1.\u6e2c\u8a66') + ), + Column(u'Unitéble2_b', Integer, + ForeignKey(u'Unitéble2.b') + ), + Column(u'\u6e2c\u8a66_self', Integer, + ForeignKey(u'\u6e2c\u8a66.\u6e2c\u8a66_id') + ), + test_needs_fk=True, + ) + else: + t3 = Table(u'\u6e2c\u8a66', metadata, + Column(u'\u6e2c\u8a66_id', Integer, primary_key=True, + autoincrement=False), + Column(u'unitable1_\u6e2c\u8a66', Integer), + Column(u'Unitéble2_b', Integer), + Column(u'\u6e2c\u8a66_self', Integer), + test_needs_fk=True, + ) metadata.create_all() + 
@testing.unsupported('maxdb', 'oracle', 'sybase') def tearDown(self): - t2.delete().execute() - t1.delete().execute() - + if metadata.tables: + t3.delete().execute() + t2.delete().execute() + t1.delete().execute() + + @testing.unsupported('maxdb', 'oracle', 'sybase') def tearDownAll(self): global unicode_bind metadata.drop_all() del unicode_bind - def _unicode_bind(self): - if testbase.db.name != 'mysql': - return testbase.db - else: - # most mysql installations don't default to utf8 connections - version = testbase.db.dialect.get_version_info(testbase.db) - if version < (4, 1): - raise AssertionError("Unicode not supported on MySQL < 4.1") - - c = testbase.db.connect() - if not hasattr(c.connection.connection, 'set_character_set'): - raise AssertionError( - "Unicode not supported on this MySQL-python version") - else: - c.connection.set_character_set('utf8') - c.detach() - - return c - + @testing.unsupported('maxdb', 'oracle', 'sybase') def test_insert(self): t1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5}) t2.insert().execute({'a':1, 'b':1}) - + t3.insert().execute({u'\u6e2c\u8a66_id': 1, + u'unitable1_\u6e2c\u8a66': 5, + u'Unitéble2_b': 1, + u'\u6e2c\u8a66_self': 1}) + assert t1.select().execute().fetchall() == [(1, 5)] assert t2.select().execute().fetchall() == [(1, 1)] - + assert t3.select().execute().fetchall() == [(1, 5, 1, 1)] + + @testing.unsupported('maxdb', 'oracle', 'sybase') def test_reflect(self): t1.insert().execute({u'méil':2, u'\u6e2c\u8a66':7}) t2.insert().execute({'a':2, 'b':2}) + t3.insert().execute({u'\u6e2c\u8a66_id': 2, + u'unitable1_\u6e2c\u8a66': 7, + u'Unitéble2_b': 2, + u'\u6e2c\u8a66_self': 2}) meta = MetaData(unicode_bind) tt1 = Table(t1.name, meta, autoload=True) tt2 = Table(t2.name, meta, autoload=True) + tt3 = Table(t3.name, meta, autoload=True) tt1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5}) tt2.insert().execute({u'méil':1, u'\u6e2c\u8a66':1}) + tt3.insert().execute({u'\u6e2c\u8a66_id': 1, + u'unitable1_\u6e2c\u8a66': 5, + u'Unitéble2_b': 1, + u'\u6e2c\u8a66_self': 1}) + + self.assert_(tt1.select(order_by=desc(u'méil')).execute().fetchall() == + [(2, 7), (1, 5)]) + self.assert_(tt2.select(order_by=desc(u'méil')).execute().fetchall() == + [(2, 2), (1, 1)]) + self.assert_(tt3.select(order_by=desc(u'\u6e2c\u8a66_id')). + execute().fetchall() == + [(2, 7, 2, 2), (1, 5, 1, 1)]) + meta.drop_all() + metadata.create_all() + +class EscapesDefaultsTest(testing.TestBase): + def test_default_exec(self): + metadata = MetaData(testing.db) + t1 = Table('t1', metadata, + Column(u'special_col', Integer, Sequence('special_col'), primary_key=True), + Column('data', String(50)) # to appease SQLite without DEFAULT VALUES + ) + t1.create() + + try: + engine = metadata.bind + + # reset the identifier preparer, so that we can force it to cache + # a unicode identifier + engine.dialect.identifier_preparer = engine.dialect.preparer(engine.dialect) + select([column(u'special_col')]).select_from(t1).execute() + assert isinstance(engine.dialect.identifier_preparer.format_sequence(Sequence('special_col')), unicode) + + # now execute, run the sequence. it should run in u"Special_col.nextid" or similar as + # a unicode object; cx_oracle asserts that this is None or a String (postgres lets it pass thru). + # ensure that base.DefaultRunner is encoding. 
+ t1.insert().execute(data='foo') + finally: + t1.drop() + - assert tt1.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 7), (1, 5)] - assert tt2.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 2), (1, 1)] - - def test_mapping(self): - # TODO: this test should be moved to the ORM tests, tests should be - # added to this module testing SQL syntax and joins, etc. - class A(object):pass - class B(object):pass - - mapper(A, t1, properties={ - 't2s':relation(B), - 'a':t1.c[u'méil'], - 'b':t1.c[u'\u6e2c\u8a66'] - }) - mapper(B, t2) - sess = create_session() - a1 = A() - b1 = B() - a1.t2s.append(b1) - sess.save(a1) - sess.flush() - sess.clear() - new_a1 = sess.query(A).selectone(t1.c[u'méil'] == a1.a) - assert new_a1.a == a1.a - assert new_a1.t2s[0].a == b1.a - sess.clear() - new_a1 = sess.query(A).options(eagerload('t2s')).selectone(t1.c[u'méil'] == a1.a) - assert new_a1.a == a1.a - assert new_a1.t2s[0].a == b1.a - if __name__ == '__main__': - testbase.main() + testenv.main() diff --git a/test/testbase.py b/test/testbase.py deleted file mode 100644 index 1195db3400..0000000000 --- a/test/testbase.py +++ /dev/null @@ -1,14 +0,0 @@ -"""First import for all test cases, sets sys.path and loads configuration.""" - -__all__ = 'db', - -import sys, os, logging -sys.path.insert(0, os.path.join(os.getcwd(), 'lib')) -logging.basicConfig() - -import testlib.config -testlib.config.configure() - -from testlib.testing import main -db = testlib.config.db - diff --git a/test/testenv.py b/test/testenv.py new file mode 100644 index 0000000000..35e9032aad --- /dev/null +++ b/test/testenv.py @@ -0,0 +1,35 @@ +"""First import for all test cases, sets sys.path and loads configuration.""" + +import sys, os, logging, warnings + +if sys.version_info < (2, 4): + warnings.filterwarnings('ignore', category=FutureWarning) + +from testlib.testing import main +import testlib.config + + +_setup = False + +def configure_for_tests(): + """import testenv; testenv.configure_for_tests()""" + + global _setup + if not _setup: + sys.path.insert(0, os.path.join(os.getcwd(), 'lib')) + logging.basicConfig() + + testlib.config.configure() + _setup = True + +def simple_setup(): + """import testenv; testenv.simple_setup()""" + + global _setup + if not _setup: + sys.path.insert(0, os.path.join(os.getcwd(), 'lib')) + logging.basicConfig() + + testlib.config.configure_defaults() + _setup = True + diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py index ff5c4c125e..98552b0f39 100644 --- a/test/testlib/__init__.py +++ b/test/testlib/__init__.py @@ -5,7 +5,19 @@ Load after sqlalchemy imports to use instrumented stand-ins like Table. 
import testlib.config from testlib.schema import Table, Column +from testlib.orm import mapper import testlib.testing as testing -from testlib.testing import PersistTest, AssertMixin, ORMTest -import testlib.profiling +from testlib.testing import rowset +from testlib.testing import TestBase, AssertsExecutionResults, ORMTest, AssertsCompiledSQL, ComparesTables +import testlib.profiling as profiling +import testlib.engines as engines +from testlib.compat import set, frozenset, sorted, _function_named + +__all__ = ('testing', + 'mapper', + 'Table', 'Column', + 'rowset', + 'TestBase', 'AssertsExecutionResults', 'ORMTest', 'AssertsCompiledSQL', 'ComparesTables', + 'profiling', 'engines', + 'set', 'frozenset', 'sorted', '_function_named') diff --git a/test/testlib/compat.py b/test/testlib/compat.py new file mode 100644 index 0000000000..ba12b78ac8 --- /dev/null +++ b/test/testlib/compat.py @@ -0,0 +1,92 @@ +import itertools, new, sys, warnings + +__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque' + +try: + set = set +except NameError: + import sets + + # keep this in sync with sqlalchemy.util.Set + # can't just import it in testlib because of coverage, load order, etc. + class set(sets.Set): + def _binary_sanity_check(self, other): + pass + + def issubset(self, iterable): + other = type(self)(iterable) + return sets.Set.issubset(self, other) + def __le__(self, other): + sets.Set._binary_sanity_check(self, other) + return sets.Set.__le__(self, other) + def issuperset(self, iterable): + other = type(self)(iterable) + return sets.Set.issuperset(self, other) + def __ge__(self, other): + sets.Set._binary_sanity_check(self, other) + return sets.Set.__ge__(self, other) + + # lt and gt still require a BaseSet + def __lt__(self, other): + sets.Set._binary_sanity_check(self, other) + return sets.Set.__lt__(self, other) + def __gt__(self, other): + sets.Set._binary_sanity_check(self, other) + return sets.Set.__gt__(self, other) + + def __ior__(self, other): + if not isinstance(other, sets.BaseSet): + return NotImplemented + return sets.Set.__ior__(self, other) + def __iand__(self, other): + if not isinstance(other, sets.BaseSet): + return NotImplemented + return sets.Set.__iand__(self, other) + def __ixor__(self, other): + if not isinstance(other, sets.BaseSet): + return NotImplemented + return sets.Set.__ixor__(self, other) + def __isub__(self, other): + if not isinstance(other, sets.BaseSet): + return NotImplemented + return sets.Set.__isub__(self, other) + +try: + frozenset = frozenset +except NameError: + import sets + from sets import ImmutableSet as frozenset + +try: + sorted = sorted +except NameError: + def sorted(iterable, cmp=None): + l = list(iterable) + if cmp: + l.sort(cmp) + else: + l.sort() + return l + +try: + from collections import deque +except ImportError: + class deque(list): + def appendleft(self, x): + self.insert(0, x) + def popleft(self): + return self.pop(0) + def extendleft(self, iterable): + items = list(iterable) + items.reverse() + for x in items: + self.insert(0, x) + +def _function_named(fn, newname): + try: + fn.__name__ = newname + except: + fn = new.function(fn.func_code, fn.func_globals, newname, + fn.func_defaults, fn.func_closure) + return fn + diff --git a/test/testlib/config.py b/test/testlib/config.py index f05cda46d3..ac9f397177 100644 --- a/test/testlib/config.py +++ b/test/testlib/config.py @@ -1,9 +1,11 @@ -import optparse, os, sys, ConfigParser, StringIO +import optparse, os, sys, re, ConfigParser, StringIO, time, warnings logging, require = 
None, None + __all__ = 'parser', 'configure', 'options', -db, db_uri, db_type, db_label = None, None, None, None +db = None +db_label, db_url, db_opts = None, None, {} options = None file_config = None @@ -17,7 +19,8 @@ mysql=mysql://scott:tiger@127.0.0.1:3306/test oracle=oracle://scott:tiger@127.0.0.1:1521 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test -firebird=firebird://sysdba:s@localhost/tmp/test.fdb +firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb +maxdb=maxdb://MONA:RED@/maxdb1 """ parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]") @@ -40,6 +43,29 @@ def configure(): return options, file_config +def configure_defaults(): + global options, config + global getopts_options, file_config + global db + + file_config = ConfigParser.ConfigParser() + file_config.readfp(StringIO.StringIO(base_config)) + file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) + (options, args) = parser.parse_args([]) + + # make error messages raised by decorators that depend on a default + # database clearer. + class _engine_bomb(object): + def __getattr__(self, key): + raise RuntimeError('No default engine available, testlib ' + 'was configured with defaults only.') + + db = _engine_bomb() + import testlib.testing + testlib.testing.db = db + + return options, file_config + def _log(option, opt_str, value, parser): global logging if not logging: @@ -70,13 +96,20 @@ def _start_coverage(option, opt_str, value, parser): atexit.register(_stop) coverage.erase() coverage.start() - + def _list_dbs(*args): print "Available --db options (use --dburi to override)" for macro in sorted(file_config.options('db')): print "%20s\t%s" % (macro, file_config.get('db', macro)) sys.exit(0) +def _server_side_cursors(options, opt_str, value, parser): + db_opts['server_side_cursors'] = True + +def _engine_strategy(options, opt_str, value, parser): + if value: + db_opts['strategy'] = value + opt = parser.add_option opt("--verbose", action="store_true", dest="verbose", help="enable stdout echoing/printing") @@ -93,14 +126,23 @@ opt('--dbs', action='callback', callback=_list_dbs, help="List available prefab dbs") opt("--dburi", action="store", dest="dburi", help="Database uri (overrides --db)") +opt("--dropfirst", action="store_true", dest="dropfirst", + help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)") opt("--mockpool", action="store_true", dest="mockpool", help="Use mock pool (asserts only one connection used)") -opt("--enginestrategy", action="store", dest="enginestrategy", default=None, - help="Engine strategy (plain or threadlocal, defaults toplain)") +opt("--enginestrategy", action="callback", type="string", + callback=_engine_strategy, + help="Engine strategy (plain or threadlocal, defaults to plain)") opt("--reversetop", action="store_true", dest="reversetop", default=False, help="Reverse the collection ordering for topological sorts (helps " "reveal dependency issues)") -opt("--serverside", action="store_true", dest="serverside", +opt("--unhashable", action="store_true", dest="unhashable", default=False, + help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") +opt("--noncomparable", action="store_true", dest="noncomparable", default=False, + help="Disallow SQLAlchemy from performing == on mapped test objects.") +opt("--truthless", action="store_true", dest="truthless", default=False, + help="Disallow SQLAlchemy from truth-evaluating mapped test objects.") 
+opt("--serverside", action="callback", callback=_server_side_cursors, help="Turn on server side cursors for PG") opt("--mysql-engine", action="store", dest="mysql_engine", default=None, help="Use the specified MySQL storage engine for all tables, default is " @@ -130,24 +172,26 @@ class _ordered_map(object): def __iter__(self): for key in self._keys: yield self._data[key] - + +# at one point in refactoring, modules were injecting into the config +# process. this could probably just become a list now. post_configure = _ordered_map() def _engine_uri(options, file_config): - global db_label, db_uri + global db_label, db_url db_label = 'sqlite' if options.dburi: - db_uri = options.dburi - db_label = db_uri[:db_uri.index(':')] + db_url = options.dburi + db_label = db_url[:db_url.index(':')] elif options.db: db_label = options.db - db_uri = None + db_url = None - if db_uri is None: + if db_url is None: if db_label not in file_config.options('db'): raise RuntimeError( "Unknown engine. Specify --dbs for known engines.") - db_uri = file_config.get('db', db_label) + db_url = file_config.get('db', db_label) post_configure['engine_uri'] = _engine_uri def _require(options, file_config): @@ -176,36 +220,52 @@ def _require(options, file_config): pkg_resources.require(requirement) post_configure['require'] = _require -def _create_testing_engine(options, file_config): - from sqlalchemy import engine - global db, db_type - engine_opts = {} - if options.serverside: - engine_opts['server_side_cursors'] = True - - if options.enginestrategy is not None: - engine_opts['strategy'] = options.enginestrategy - +def _engine_pool(options, file_config): if options.mockpool: - db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, - **engine_opts) - else: - db = engine.create_engine(db_uri, **engine_opts) - db_type = db.name - - # decorate the dialect's create_execution_context() method - # to produce a wrapper - from testlib.testing import ExecutionContextWrapper - - create_context = db.dialect.create_execution_context - def create_exec_context(*args, **kwargs): - return ExecutionContextWrapper(create_context(*args, **kwargs)) - db.dialect.create_execution_context = create_exec_context + from sqlalchemy import pool + db_opts['poolclass'] = pool.AssertionPool +post_configure['engine_pool'] = _engine_pool + +def _create_testing_engine(options, file_config): + from testlib import engines, testing + global db + db = engines.testing_engine(db_url, db_opts) + testing.db = db post_configure['create_engine'] = _create_testing_engine +def _prep_testing_database(options, file_config): + from testlib import engines + from sqlalchemy import schema + + try: + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + if not options.quiet: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." 
+ time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + except (KeyboardInterrupt, SystemExit): + raise + except Exception, e: + if not options.quiet: + warnings.warn(RuntimeWarning( + "Error checking for existing tables in testing " + "database: %s" % e)) +post_configure['prep_db'] = _prep_testing_database + def _set_table_options(options, file_config): import testlib.schema - + table_options = testlib.schema.table_options for spec in options.tableopts: key, value = spec.split('=') @@ -231,7 +291,7 @@ post_configure['topological'] = _reverse_topological def _set_profile_targets(options, file_config): from testlib import profiling - + profile_config = profiling.profile_config for target in options.profile_targets: diff --git a/test/testlib/engines.py b/test/testlib/engines.py new file mode 100644 index 0000000000..f5694df57e --- /dev/null +++ b/test/testlib/engines.py @@ -0,0 +1,224 @@ +import sys, types, weakref +from testlib import config +from testlib.compat import * + + +class ConnectionKiller(object): + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + + def checkout(self, dbapi_con, con_record, con_proxy): + self.proxy_refs[con_proxy] = True + + def _apply_all(self, methods): + for rec in self.proxy_refs: + if rec is not None and rec.is_valid: + try: + for name in methods: + if callable(name): + name(rec) + else: + getattr(rec, name)() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + # fixme + sys.stderr.write("\n" + str(e) + "\n") + + def rollback_all(self): + self._apply_all(('rollback',)) + + def close_all(self): + self._apply_all(('rollback', 'close')) + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + +testing_reaper = ConnectionKiller() + +def assert_conns_closed(fn): + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + return _function_named(decorated, fn.__name__) + +def rollback_open_connections(fn): + """Decorator that rolls back all open connections after fn execution.""" + + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + return _function_named(decorated, fn.__name__) + +def close_open_connections(fn): + """Decorator that closes all connections after fn execution.""" + + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.close_all() + return _function_named(decorated, fn.__name__) + +class ReconnectFixture(object): + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + self.connections.append(conn) + return conn + + def shutdown(self): + for c in list(self.connections): + c.close() + self.connections = [] + +def reconnecting_engine(url=None, options=None): + url = url or config.db_url + dbapi = config.db.dialect.dbapi + engine = testing_engine(url, {'module':ReconnectFixture(dbapi)}) + engine.test_shutdown = engine.dialect.dbapi.shutdown + return engine + +def testing_engine(url=None, options=None): + """Produce an engine configured by --options with optional overrides.""" + + from sqlalchemy import create_engine + from testlib.testing import ExecutionContextWrapper + + url = url or config.db_url + options = options or config.db_opts + + listeners = options.setdefault('listeners', []) + listeners.append(testing_reaper) + + engine = create_engine(url, 
**options) + + create_context = engine.dialect.create_execution_context + def create_exec_context(*args, **kwargs): + return ExecutionContextWrapper(create_context(*args, **kwargs)) + engine.dialect.create_execution_context = create_exec_context + return engine + +def utf8_engine(url=None, options=None): + """Hook for dialects or drivers that don't handle utf8 by default.""" + + from sqlalchemy.engine import url as engine_url + + if config.db.name == 'mysql': + dbapi_ver = config.db.dialect.dbapi.version_info + if (dbapi_ver < (1, 2, 1) or + dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2), + (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))): + raise RuntimeError('Character set support unavailable with this ' + 'driver version: %s' % repr(dbapi_ver)) + else: + url = url or config.db_url + url = engine_url.make_url(url) + url.query['charset'] = 'utf8' + url.query['use_unicode'] = '0' + url = str(url) + + return testing_engine(url, options) + + +class ReplayableSession(object): + """A simple record/playback tool. + + This is *not* a mock testing class. It only records a session for later + playback and makes no assertions on call consistency whatsoever. It's + unlikely to be suitable for anything other than DB-API recording. + + """ + + Callable = object() + NoAttribute = object() + Natives = set([getattr(types, t) + for t in dir(types) if not t.startswith('_')]). \ + difference([getattr(types, t) + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', 'UnboundMethodType',)]) + def __init__(self): + self.buffer = deque() + + def recorder(self, base): + return self.Recorder(self.buffer, base) + + def player(self): + return self.Player(self.buffer) + + class Recorder(object): + def __init__(self, buffer, subject): + self._buffer = buffer + self._subject = subject + + def __call__(self, *args, **kw): + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + + result = subject(*args, **kw) + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + try: + result = type(subject).__getattribute__(subject, key) + except AttributeError: + buffer.append(ReplayableSession.NoAttribute) + raise + else: + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + class Player(object): + def __init__(self, buffer): + self._buffer = buffer + + def __call__(self, *args, **kw): + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + else: + return result + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + elif result is ReplayableSession.NoAttribute: + raise AttributeError(key) + else: + return result diff --git a/test/testlib/filters.py b/test/testlib/filters.py new file mode 100644 index 0000000000..eb7eff279b --- /dev/null +++ b/test/testlib/filters.py @@ -0,0 +1,239 @@ +"""A 
collection of Python source transformers. + +Supports the 'clone' command, providing source code transforms to run the test +suite on pre Python 2.4-level parser implementations. + +Includes:: + + py23 + Converts 2.4-level source code into 2.3-parsable source. + Currently only rewrites @decorators, but generator transformations + are possible. + py23_decorators + py23 is currently an alias for py23_decorators. +""" + +import sys +from StringIO import StringIO +from tokenize import * + +__all__ = ['py23_decorators', 'py23'] + + +def py23_decorators(lines): + """Translates @decorators in source lines to 2.3 syntax.""" + + tokens = peekable(generate_tokens(iter(lines).next)) + text = untokenize(backport_decorators(tokens)) + return [x + '\n' for x in text.split('\n')] + +py23 = py23_decorators + + +def backport_decorators(stream): + """Restates @decorators in 2.3 syntax + + Operates on token streams. Converts:: + + @foo + @bar(1, 2) + def quux(): + pass + into:: + + def quux(): + pass + quux = bar(1, 2)(quux) + quux = foo(quux) + + Fails on decorated one-liners:: + + @decorator + def fn(): pass + """ + + if not hasattr(stream, 'peek'): + stream = peekable(iter(stream)) + + stack = [_DecoratorState('')] + emit = [] + for ttype, tok, _, _, _ in stream: + current = stack[-1] + if ttype == INDENT: + current = _DecoratorState(tok) + stack.append(current) + elif ttype == DEDENT: + previous = stack.pop() + assert not previous.decorations + current = stack[-1] + if current.decorations: + ws = pop_trailing_whitespace(emit) + + emit.append((ttype, tok)) + for decorator, misc in reversed(current.decorations): + if not decorator or decorator[0][1] != '@': + emit.extend(decorator) + else: + emit.extend( + [(NAME, current.fn_name), (OP, '=')] + + decorator[1:] + + [(OP, '('), (NAME, current.fn_name), (OP, ')')]) + emit.extend(misc) + current.decorations = [] + emit.extend(ws) + continue + elif ttype == OP and tok == '@': + current.in_decorator = True + decoration = [(ttype, tok)] + current.decorations.append((decoration, [])) + current.consume_identifier(stream) + if stream.peek()[1] == '(': + current.consume_parened(stream) + continue + elif ttype == NAME and tok == 'def': + current.in_decorator = False + current.fn_name = stream.peek()[1] + elif current.in_decorator: + current.append_misc((ttype, tok)) + continue + + emit.append((ttype, tok)) + return emit + +class _DecoratorState(object): + """Holds state for restating decorators as function calls.""" + + in_decorator = False + fn_name = None + def __init__(self, indent): + self.indent = indent + self.decorations = [] + def append_misc(self, token): + if not self.decorations: + self.decorations.append(([], [])) + self.decorations[-1][1].append(token) + def consume_identifier(self, stream): + while True: + typ, value = stream.peek()[:2] + if not (typ == NAME or (typ == OP and value == '.')): + break + self.decorations[-1][0].append(stream.next()[:2]) + def consume_parened(self, stream): + """Consume a (paren) sequence from a token seq starting with (""" + depth, offsets = 0, {'(':1, ')':-1} + while True: + typ, value = stream.next()[:2] + if typ == OP: + depth += offsets.get(value, 0) + self.decorations[-1][0].append((typ, value)) + if depth == 0: + break + +def pop_trailing_whitespace(tokens): + """Removes trailing whitespace tokens from a token list.""" + + popped = [] + for token in reversed(list(tokens)): + if token[0] not in (NL, COMMENT): + break + popped.append(tokens.pop()) + return popped + +def untokenize(iterable): + """Turns a stream of 
tokens into a Python source str. + + A PEP-8-ish variant of Python 2.5+'s tokenize.untokenize. Produces output + that's not perfect, but is at least readable. The stdlib version is + basically unusable. + """ + + if not hasattr(iterable, 'peek'): + iterable = peekable(iter(iterable)) + + startline = False + indents = [] + toks = [] + toks_append = toks.append + + # this is pretty roughly hacked. i think it could get very close to + # perfect by rewriting to operate over a sliding window of + # (prev, current, next) token sets + making some grouping macros to + # include all the tokens and operators this omits. + for tok in iterable: + toknum, tokval = tok[:2] + + try: + next_num, next_val = iterable.peek()[:2] + except StopIteration: + next_num, next_val = None, None + + if toknum == NAME: + if tokval == 'in': + tokval += ' ' + elif next_num == OP: + if next_val not in ('(', ')', '[', ']', '{', '}', + ':', '.', ',',): + tokval += ' ' + elif next_num != NEWLINE: + tokval += ' ' + elif toknum == OP: + if tokval in ('(', '@', '.', '[', '{', '*', '**'): + pass + elif tokval in ('%', ':') and next_num not in (NEWLINE, ): + tokval += ' ' + elif next_num in (NAME, COMMENT, + NUMBER, STRING): + tokval += ' ' + elif (tokval in (')', ']', '}') and next_num == OP and + '=' in next_val): + tokval += ' ' + elif tokval == ',' or '=' in tokval: + tokval += ' ' + elif toknum in (NUMBER, STRING): + if next_num == OP and next_val not in (')', ']', '}', ',', ':'): + tokval += ' ' + elif next_num == NAME: + tokval += ' ' + + # would be nice to indent continued lines... + if toknum == INDENT: + indents.append(tokval) + continue + elif toknum == DEDENT: + indents.pop() + continue + elif toknum in (NEWLINE, COMMENT, NL): + startline = True + elif startline and indents: + toks_append(indents[-1]) + startline = False + toks_append(tokval) + return ''.join(toks) + + +class peekable(object): + """A iterator wrapper that allows peek()ing at the next value.""" + + def __init__(self, iterator): + self.iterator = iterator + self.buffer = [] + def next(self): + if self.buffer: + return self.buffer.pop(0) + return self.iterator.next() + def peek(self): + if self.buffer: + return self.buffer[0] + x = self.iterator.next() + self.buffer.append(x) + return x + def __iter__(self): + return self + +if __name__ == '__main__': + # runnable. converts a named file to 2.3. + input = open(len(sys.argv) == 2 and sys.argv[1] or __file__) + + tokens = generate_tokens(input.readline) + back = backport_decorators(tokens) + print untokenize(back) diff --git a/test/orm/fixtures.py b/test/testlib/fixtures.py similarity index 72% rename from test/orm/fixtures.py rename to test/testlib/fixtures.py index 4a7d41459f..e8d71179a8 100644 --- a/test/orm/fixtures.py +++ b/test/testlib/fixtures.py @@ -1,54 +1,84 @@ -import testbase +# can't be imported until the path is setup; be sure to configure +# first if covering. 
from sqlalchemy import * +from sqlalchemy import util from testlib import * +__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest', 'Dingaling', 'item_keywords', + 'dingalings', 'User', 'items', 'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users', + 'order_items', 'Item', 'Order', 'fixtures'] + _recursion_stack = util.Set() class Base(object): def __init__(self, **kwargs): for k in kwargs: setattr(self, k, kwargs[k]) - + + # TODO: add recursion checks to this + def __repr__(self): + return "%s(%s)" % ( + (self.__class__.__name__), + ','.join(["%s=%s" % (key, repr(getattr(self, key))) for key in self.__dict__ if not key.startswith('_')]) + ) + def __ne__(self, other): return not self.__eq__(other) - + def __eq__(self, other): """'passively' compare this object to another. - + only look at attributes that are present on the source object. - + """ - + if self in _recursion_stack: return True _recursion_stack.add(self) try: - # use __dict__ to avoid instrumented properties - for attr in self.__dict__.keys(): + # pick the entity thats not SA persisted as the source + if other is None: + a = self + b = other + elif hasattr(self, '_instance_key'): + a = other + b = self + else: + a = self + b = other + + for attr in a.__dict__.keys(): if attr[0] == '_': continue - value = getattr(self, attr) + value = getattr(a, attr) + #print "looking at attr:", attr, "start value:", value if hasattr(value, '__iter__') and not isinstance(value, basestring): - if len(value) == 0: + try: + # catch AttributeError so that lazy loaders trigger + battr = getattr(b, attr) + except AttributeError: + #print "b class does not have attribute named '%s'" % attr + return False + + if list(value) == list(battr): continue - for (us, them) in zip(value, getattr(other, attr)): - if us != them: - return False else: - continue + return False else: if value is not None: - if value != getattr(other, attr, None): + if value != getattr(b, attr, None): + #print "2. 
Attribute named '%s' does not match that of b" % attr return False else: return True finally: _recursion_stack.remove(self) - + class User(Base):pass class Order(Base):pass class Item(Base):pass class Keyword(Base):pass class Address(Base):pass +class Dingaling(Base):pass metadata = MetaData() @@ -64,12 +94,18 @@ orders = Table('orders', metadata, Column('isopen', Integer) ) -addresses = Table('addresses', metadata, +addresses = Table('addresses', metadata, Column('id', Integer, primary_key=True), Column('user_id', None, ForeignKey('users.id')), Column('email_address', String(50), nullable=False)) -items = Table('items', metadata, +dingalings = Table("dingalings", metadata, + Column('id', Integer, primary_key=True), + Column('address_id', None, ForeignKey('addresses.id')), + Column('data', String(30)) + ) + +items = Table('items', metadata, Column('id', Integer, primary_key=True), Column('description', String(30), nullable=False) ) @@ -78,11 +114,11 @@ order_items = Table('order_items', metadata, Column('item_id', None, ForeignKey('items.id')), Column('order_id', None, ForeignKey('orders.id'))) -item_keywords = Table('item_keywords', metadata, +item_keywords = Table('item_keywords', metadata, Column('item_id', None, ForeignKey('items.id')), Column('keyword_id', None, ForeignKey('keywords.id'))) -keywords = Table('keywords', metadata, +keywords = Table('keywords', metadata, Column('id', Integer, primary_key=True), Column('name', String(30), nullable=False) ) @@ -102,12 +138,16 @@ def install_fixture_data(): dict(id = 4, user_id = 8, email_address = "ed@lala.com"), dict(id = 5, user_id = 9, email_address = "fred@fred.com"), ) + dingalings.insert().execute( + dict(id=1, address_id=2, data='ding 1/2'), + dict(id=2, address_id=5, data='ding 2/5'), + ) orders.insert().execute( dict(id = 1, user_id = 7, description = 'order 1', isopen=0, address_id=1), dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4), dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1), dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4), - dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1) + dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=None) ) items.insert().execute( dict(id=1, description='item 1'), @@ -146,7 +186,7 @@ def install_fixture_data(): # this many-to-many table has the keywords inserted # in primary key order, to appease the unit tests. - # this is because postgres, oracle, and sqlite all support + # this is because postgres, oracle, and sqlite all support # true insert-order row id, but of course our pal MySQL does not, # so the best it can do is order by, well something, so there you go. 
item_keywords.insert().execute( @@ -161,21 +201,37 @@ def install_fixture_data(): dict(keyword_id=6, item_id=3) ) +class FixtureTest(ORMTest): + refresh_data = False + + def setUpAll(self): + super(FixtureTest, self).setUpAll() + if self.keep_data: + install_fixture_data() + + def setUp(self): + if self.refresh_data: + install_fixture_data() + + def define_tables(self, meta): + pass +FixtureTest.metadata = metadata + class Fixtures(object): @property def user_address_result(self): return [ User(id=7, addresses=[ Address(id=1) - ]), + ]), User(id=8, addresses=[ Address(id=2, email_address='ed@wood.com'), Address(id=3, email_address='ed@bettyboop.com'), Address(id=4, email_address='ed@lala.com'), - ]), + ]), User(id=9, addresses=[ Address(id=5) - ]), + ]), User(id=10, addresses=[]) ] @@ -188,18 +244,18 @@ class Fixtures(object): Order(description='order 1', items=[Item(description='item 1'), Item(description='item 2'), Item(description='item 3')]), Order(description='order 3'), Order(description='order 5'), - ]), + ]), User(id=8, addresses=[ Address(id=2), Address(id=3), Address(id=4) - ]), + ]), User(id=9, addresses=[ Address(id=5) ], orders=[ Order(description='order 2', items=[Item(description='item 1'), Item(description='item 2'), Item(description='item 3')]), Order(description='order 4', items=[Item(description='item 1'), Item(description='item 5')]), - ]), + ]), User(id=10, addresses=[]) ] @@ -217,8 +273,8 @@ class Fixtures(object): Order(id=4, items=[Item(id=1), Item(id=5)]), ]), User(id=10) - ] - + ] + @property def item_keyword_result(self): return [ diff --git a/test/testlib/orm.py b/test/testlib/orm.py new file mode 100644 index 0000000000..d0ec155e67 --- /dev/null +++ b/test/testlib/orm.py @@ -0,0 +1,115 @@ +import inspect, re +from testlib import config +orm = None + +__all__ = 'mapper', + + +_whitespace = re.compile(r'^(\s+)') + +def _find_pragma(lines, current): + m = _whitespace.match(lines[current]) + basis = m and m.group() or '' + + for line in reversed(lines[0:current]): + if 'testlib.pragma' in line: + return line + m = _whitespace.match(line) + indent = m and m.group() or '' + + # simplistic detection: + + # >> # testlib.pragma foo + # >> center_line() + if indent == basis: + break + # >> # testlib.pragma foo + # >> if fleem: + # >> center_line() + if line.endswith(':'): + break + return None + +def _make_blocker(method_name, fallback): + """Creates tripwired variant of a method, raising when called. + + To excempt an invocation from blockage, there are two options. + + 1) add a pragma in a comment:: + + # testlib.pragma exempt:methodname + offending_line() + + 2) add a magic cookie to the function's namespace:: + __sa_baremethodname_exempt__ = True + ... + offending_line() + another_offending_lines() + + The second is useful for testing and development. 
+ """ + + if method_name.startswith('__') and method_name.endswith('__'): + frame_marker = '__sa_%s_exempt__' % method_name[2:-2] + else: + frame_marker = '__sa_%s_exempt__' % method_name + pragma_marker = 'exempt:' + method_name + + def method(self, *args, **kw): + frame_r = None + try: + frame = inspect.stack()[1][0] + frame_r = inspect.getframeinfo(frame, 9) + + module = frame.f_globals.get('__name__', '') + + type_ = type(self) + + pragma = _find_pragma(*frame_r[3:5]) + + exempt = ( + (not module.startswith('sqlalchemy')) or + (pragma and pragma_marker in pragma) or + (frame_marker in frame.f_locals) or + ('self' in frame.f_locals and + getattr(frame.f_locals['self'], frame_marker, False))) + + if exempt: + supermeth = getattr(super(type_, self), method_name, None) + if (supermeth is None or + getattr(supermeth, 'im_func', None) is method): + return fallback(self, *args, **kw) + else: + return supermeth(*args, **kw) + else: + raise AssertionError( + "%s.%s called in %s, line %s in %s" % ( + type_.__name__, method_name, module, frame_r[1], frame_r[2])) + finally: + del frame + method.__name__ = method_name + return method + +def mapper(type_, *args, **kw): + global orm + if orm is None: + from sqlalchemy import orm + + forbidden = [ + ('__hash__', 'unhashable', lambda s: id(s)), + ('__eq__', 'noncomparable', lambda s, o: s is o), + ('__ne__', 'noncomparable', lambda s, o: s is not o), + ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)), + ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)), + ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)), + ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)), + ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)), + ('__nonzero__', 'truthless', lambda s: 1), ] + + if type_.__bases__ == (object,): + for method_name, option, fallback in forbidden: + if (getattr(config.options, option, False) and + method_name not in type_.__dict__): + setattr(type_, method_name, _make_blocker(method_name, fallback)) + + return orm.mapper(type_, *args, **kw) diff --git a/test/testlib/profiling.py b/test/testlib/profiling.py index 697df4ea2e..b452d1fb82 100644 --- a/test/testlib/profiling.py +++ b/test/testlib/profiling.py @@ -1,23 +1,26 @@ """Profiling support for unit and performance tests.""" +import os, sys from testlib.config import parser, post_configure +from testlib.compat import * import testlib.config -__all__ = 'profiled', +__all__ = 'profiled', 'function_call_count', 'conditional_call_count' all_targets = set() profile_config = { 'targets': set(), 'report': True, 'sort': ('time', 'calls'), 'limit': None } +profiler = None -def profiled(target, **target_opts): +def profiled(target=None, **target_opts): """Optional function profiling. @profiled('label') or @profiled('label', report=True, sort=('calls',), limit=20) - + Enables profiling for a function when 'label' is targetted for profiling. Report options can be supplied, and override the global configuration and command-line options. 
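# A minimal usage sketch of the profiled() decorator described in the
# docstring above, assuming the test suite's testlib package is importable;
# the target label 'label' and the function name below are hypothetical.
from testlib import profiling

@profiling.profiled('label', report=True, sort=('calls',), limit=20)
def some_heavy_case():
    # runs under the profiler only when 'label' is among the configured
    # profile targets; report/sort/limit override the global settings
    pass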
@@ -26,7 +29,9 @@ def profiled(target, **target_opts): import time, hotshot, hotshot.stats # manual or automatic namespacing by module would remove conflict issues - if target in all_targets: + if target is None: + target = 'anonymous_target' + elif target in all_targets: print "Warning: redefining profile target '%s'" % target all_targets.add(target) @@ -38,37 +43,168 @@ def profiled(target, **target_opts): not target_opts.get('always', None)): return fn(*args, **kw) - prof = hotshot.Profile(filename) - began = time.time() - prof.start() - try: - result = fn(*args, **kw) - finally: - prof.stop() - ended = time.time() - prof.close() + elapsed, load_stats, result = _profile( + filename, fn, *args, **kw) if not testlib.config.options.quiet: print "Profiled target '%s', wall time: %.2f seconds" % ( - target, ended - began) + target, elapsed) report = target_opts.get('report', profile_config['report']) - if report: + if report and testlib.config.options.verbose: sort_ = target_opts.get('sort', profile_config['sort']) limit = target_opts.get('limit', profile_config['limit']) print "Profile report for target '%s' (%s)" % ( target, filename) - stats = hotshot.stats.load(filename) + stats = load_stats() stats.sort_stats(*sort_) if limit: stats.print_stats(limit) else: stats.print_stats() + + os.unlink(filename) return result - try: - profiled.__name__ = fn.__name__ - except: - pass - return profiled + return _function_named(profiled, fn.__name__) + return decorator + +def function_call_count(count=None, versions={}, variance=0.05): + """Assert a target for a test case's function call count. + + count + Optional, general target function call count. + + versions + Optional, a dictionary of Python version strings to counts, + for example:: + + { '2.5.1': 110, + '2.5': 100, + '2.4': 150 } + + The best match for the current running python will be used. + If none match, 'count' will be used as the fallback. + + variance + An +/- deviation percentage, defaults to 5%. + """ + + # this could easily dump the profile report if --verbose is in effect + + version_info = list(sys.version_info) + py_version = '.'.join([str(v) for v in sys.version_info]) + + while version_info: + version = '.'.join([str(v) for v in version_info]) + if version in versions: + count = versions[version] + break + version_info.pop() + + if count is None: + return lambda fn: fn + + def decorator(fn): + def counted(*args, **kw): + try: + filename = "%s.prof" % fn.__name__ + + elapsed, stat_loader, result = _profile( + filename, fn, *args, **kw) + + stats = stat_loader() + calls = stats.total_calls + + if testlib.config.options.verbose: + stats.sort_stats('calls', 'cumulative') + stats.print_stats() + + deviance = int(count * variance) + if (calls < (count - deviance) or + calls > (count + deviance)): + raise AssertionError( + "Function call count %s not within %s%% " + "of expected %s. (Python version %s)" % ( + calls, (variance * 100), count, py_version)) + + return result + finally: + if os.path.exists(filename): + os.unlink(filename) + return _function_named(counted, fn.__name__) + return decorator + +def conditional_call_count(discriminator, categories): + """Apply a function call count conditionally at runtime. + + Takes two arguments, a callable that returns a key value, and a dict + mapping key values to a tuple of arguments to function_call_count. + + The callable is not evaluated until the decorated function is actually + invoked. 
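A usage sketch of the two call-count decorators; the counts, version map and the environment-variable discriminator are illustrative values only:

    import os
    from testlib.profiling import function_call_count, conditional_call_count

    @function_call_count(98, {'2.4': 113, '2.5': 98}, variance=0.10)
    def test_flush_call_count():
        pass

    @conditional_call_count(
        lambda: os.environ.get('FULL_SUITE') and 'suite' or 'isolated',
        {'suite': (98,), 'isolated': (110, {}, 0.10)})
    def test_load_call_count():
        pass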
If the `discriminator` returns a key not present in the + `categories` dictionary, no call count assertion is applied. + + Useful for integration tests, where running a named test in isolation may + have a function count penalty not seen in the full suite, due to lazy + initialization in the DB-API, SA, etc. + """ + + def decorator(fn): + def at_runtime(*args, **kw): + criteria = categories.get(discriminator(), None) + if criteria is None: + return fn(*args, **kw) + + rewrapped = function_call_count(*criteria)(fn) + return rewrapped(*args, **kw) + return _function_named(at_runtime, fn.__name__) return decorator + + +def _profile(filename, fn, *args, **kw): + global profiler + if not profiler: + profiler = 'hotshot' + if sys.version_info > (2, 5): + try: + import cProfile + profiler = 'cProfile' + except ImportError: + pass + + if profiler == 'cProfile': + return _profile_cProfile(filename, fn, *args, **kw) + else: + return _profile_hotshot(filename, fn, *args, **kw) + +def _profile_cProfile(filename, fn, *args, **kw): + import cProfile, gc, pstats, time + + load_stats = lambda: pstats.Stats(filename) + gc.collect() + + began = time.time() + cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), + filename=filename) + ended = time.time() + + return ended - began, load_stats, locals()['result'] + +def _profile_hotshot(filename, fn, *args, **kw): + import gc, hotshot, hotshot.stats, time + load_stats = lambda: hotshot.stats.load(filename) + + gc.collect() + prof = hotshot.Profile(filename) + began = time.time() + prof.start() + try: + result = fn(*args, **kw) + finally: + prof.stop() + ended = time.time() + prof.close() + + return ended - began, load_stats, result + diff --git a/test/testlib/schema.py b/test/testlib/schema.py index a2fc912650..37f3591ade 100644 --- a/test/testlib/schema.py +++ b/test/testlib/schema.py @@ -1,5 +1,6 @@ -import testbase -from sqlalchemy import schema +from testlib import testing +import itertools +schema = None __all__ = 'Table', 'Column', @@ -8,21 +9,70 @@ table_options = {} def Table(*args, **kw): """A schema.Table wrapper/hook for dialect-specific tweaks.""" + global schema + if schema is None: + from sqlalchemy import schema + test_opts = dict([(k,kw.pop(k)) for k in kw.keys() if k.startswith('test_')]) kw.update(table_options) - if testbase.db.name == 'mysql': + if testing.against('mysql'): if 'mysql_engine' not in kw and 'mysql_type' not in kw: if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: kw['mysql_engine'] = 'InnoDB' + # Apply some default cascading rules for self-referential foreign keys. + # MySQL InnoDB has some issues around seleting self-refs too. + if testing.against('firebird'): + table_name = args[0] + unpack = (testing.config.db.dialect. + identifier_preparer.unformat_identifiers) + + # Only going after ForeignKeys in Columns. May need to + # expand to ForeignKeyConstraint too. 
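As a sketch of how a fixture table might opt into these hints, the wrapper above is called like plain schema.Table with extra test_* keywords (the table and column names are hypothetical):

    from sqlalchemy import MetaData, Integer, String, ForeignKey
    from testlib.schema import Table, Column

    meta = MetaData()
    nodes = Table('nodes', meta,
        Column('id', Integer, primary_key=True),
        # self-referential FK: the Firebird branch below adds CASCADE rules to it
        Column('parent_id', Integer, ForeignKey('nodes.id')),
        Column('name', String(30)),
        test_needs_fk=True)    # stripped here; forces mysql_engine='InnoDB' on MySQL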
+ fks = [fk + for col in args if isinstance(col, schema.Column) + for fk in col.args if isinstance(fk, schema.ForeignKey)] + + for fk in fks: + # root around in raw spec + ref = fk._colspec + if isinstance(ref, schema.Column): + name = ref.table.name + else: + # take just the table name: on FB there cannot be + # a schema, so the first element is always the + # table name, possibly followed by the field name + name = unpack(ref)[0] + print name, table_name + if name == table_name: + if fk.ondelete is None: + fk.ondelete = 'CASCADE' + if fk.onupdate is None: + fk.onupdate = 'CASCADE' + + if testing.against('oracle'): + pk_seqs = [col for col in args if isinstance(col, schema.Column) + and col.primary_key and getattr(col, '_needs_autoincrement', False)] + for c in pk_seqs: + c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True)) return schema.Table(*args, **kw) + def Column(*args, **kw): """A schema.Column wrapper/hook for dialect-specific tweaks.""" - # TODO: a Column that creates a Sequence automatically for PK columns, - # which would help Oracle tests - return schema.Column(*args, **kw) + global schema + if schema is None: + from sqlalchemy import schema + + test_opts = dict([(k,kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + c = schema.Column(*args, **kw) + if testing.against('oracle'): + if 'test_needs_autoincrement' in test_opts: + c._needs_autoincrement = True + return c diff --git a/test/testlib/tables.py b/test/testlib/tables.py index 69c84c5b34..33b1b20db9 100644 --- a/test/testlib/tables.py +++ b/test/testlib/tables.py @@ -1,10 +1,13 @@ -import testbase +# can't be imported until the path is setup; be sure to configure +# first if covering. from sqlalchemy import * +from testlib import testing from testlib.schema import Table, Column -# these are older test fixtures, used primarily by test/orm/mapper.py and test/orm/unitofwork.py. -# newer unit tests make usage of test/orm/fixtures.py. +# these are older test fixtures, used primarily by test/orm/mapper.py and +# test/orm/unitofwork.py. newer unit tests make usage of +# test/orm/fixtures.py. metadata = MetaData() @@ -39,7 +42,7 @@ keywords = Table('keywords', metadata, Column('name', VARCHAR(50)), ) -userkeywords = Table('userkeywords', metadata, +userkeywords = Table('userkeywords', metadata, Column('user_id', INT, ForeignKey("users")), Column('keyword_id', INT, ForeignKey("keywords")), ) @@ -52,18 +55,18 @@ itemkeywords = Table('itemkeywords', metadata, def create(): if not metadata.bind: - metadata.bind = testbase.db + metadata.bind = testing.db metadata.create_all() def drop(): if not metadata.bind: - metadata.bind = testbase.db + metadata.bind = testing.db metadata.drop_all() def delete(): for t in metadata.table_iterator(reverse=True): t.delete().execute() def user_data(): if not metadata.bind: - metadata.bind = testbase.db + metadata.bind = testing.db users.insert().execute( dict(user_id = 7, user_name = 'jack'), dict(user_id = 8, user_name = 'ed'), @@ -71,10 +74,10 @@ def user_data(): ) def delete_user_data(): users.delete().execute() - + def data(): delete() - + # with SQLITE, the OID column of a table defaults to the primary key, if it has one. 
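A sketch of the Oracle autoincrement hint consumed by the Column wrapper above (names are hypothetical); on Oracle the Table wrapper pairs such a column with an optional '<table>_<column>_seq' Sequence:

    from sqlalchemy import MetaData, Integer, String
    from testlib.schema import Table, Column

    meta = MetaData()
    accounts = Table('accounts', meta,
        Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
        Column('name', String(30)))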
# so to database-neutrally get rows back in "insert order" based on OID, we # have to also put the primary keys in order for the purpose of these tests @@ -112,10 +115,10 @@ def data(): dict(keyword_id=6, name='round'), dict(keyword_id=7, name='square') ) - + # this many-to-many table has the keywords inserted # in primary key order, to appease the unit tests. - # this is because postgres, oracle, and sqlite all support + # this is because postgres, oracle, and sqlite all support # true insert-order row id, but of course our pal MySQL does not, # so the best it can do is order by, well something, so there you go. itemkeywords.insert().execute( @@ -132,8 +135,11 @@ def data(): class BaseObject(object): def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, ",".join("%s=%s" % (k, repr(v)) for k, v in self.__dict__.iteritems() if k[0] != '_')) - + return "%s(%s)" % (self.__class__.__name__, + ",".join(["%s=%s" % (k, repr(v)) + for k, v in self.__dict__.iteritems() + if k[0] != '_'])) + class User(BaseObject): def __init__(self): self.user_id = None @@ -147,7 +153,7 @@ class Order(BaseObject): class Item(BaseObject): pass - + class Keyword(BaseObject): pass @@ -159,33 +165,33 @@ user_address_result = [ {'user_id' : 9, 'addresses' : (Address, [])} ] -user_address_orders_result = [{'user_id' : 7, +user_address_orders_result = [{'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}]), 'orders' : (Order, [{'order_id' : 1}, {'order_id' : 3},{'order_id' : 5},]) }, - {'user_id' : 8, + {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]), 'orders' : (Order, []) }, - {'user_id' : 9, + {'user_id' : 9, 'addresses' : (Address, []), 'orders' : (Order, [{'order_id' : 2},{'order_id' : 4}]) }] user_all_result = [ -{'user_id' : 7, +{'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}]), 'orders' : (Order, [ - {'order_id' : 1, 'items': (Item, [])}, + {'order_id' : 1, 'items': (Item, [])}, {'order_id' : 3, 'items': (Item, [{'item_id':3, 'item_name':'item 3'}, {'item_id':4, 'item_name':'item 4'}, {'item_id':5, 'item_name':'item 5'}])}, {'order_id' : 5, 'items': (Item, [])}, ]) }, -{'user_id' : 8, +{'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]), 'orders' : (Order, []) }, -{'user_id' : 9, +{'user_id' : 9, 'addresses' : (Address, []), 'orders' : (Order, [ {'order_id' : 2, 'items': (Item, [{'item_id':1, 'item_name':'item 1'}, {'item_id':2, 'item_name':'item 2'}])}, @@ -215,4 +221,3 @@ order_result = [ {'order_id' : 4, 'items':(Item, [])}, {'order_id' : 5, 'items':(Item, [])}, ] - diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 213772e9e1..cf0936e922 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -2,17 +2,143 @@ # monkeypatches unittest.TestLoader.suiteClass at import time -import unittest, re, sys, os +import itertools, os, operator, re, sys, unittest, warnings from cStringIO import StringIO -from sqlalchemy import MetaData, sql -from sqlalchemy.orm import clear_mappers import testlib.config as config +from testlib.compat import * -__all__ = 'PersistTest', 'AssertMixin', 'ORMTest' +sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None +sa_exceptions = None + +__all__ = ('TestBase', 'AssertsExecutionResults', 'ComparesTables', 'ORMTest', 'AssertsCompiledSQL') + +_ops = { '<': operator.lt, + '>': operator.gt, + '==': operator.eq, + '!=': operator.ne, + '<=': operator.le, + '>=': operator.ge, + 'in': 
operator.contains, + 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + } + +# sugar ('testing.db'); set here by config() at runtime +db = None + +def fails_if(callable_): + """Mark a test as expected to fail if callable_ returns True. + + If the callable returns false, the test is run and reported as normal. + However if the callable returns true, the test is expected to fail and the + unit test logic is inverted: if the test fails, a success is reported. If + the test succeeds, a failure is reported. + """ + + docstring = getattr(callable_, '__doc__', None) or callable_.__name__ + description = docstring.split('\n')[0] + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if not callable_(): + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected (condition: %s): %s " % ( + fn_name, description, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' (condition: %s)" % + (fn_name, description)) + return _function_named(maybe, fn_name) + return decorate + + +def future(fn): + """Mark a test as expected to unconditionally fail. + + Takes no arguments, omit parens when using as a decorator. + """ + + fn_name = fn.__name__ + def decorated(*args, **kw): + try: + fn(*args, **kw) + except Exception, ex: + print ("Future test '%s' failed as expected: %s " % ( + fn_name, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for future test '%s'" % fn_name) + return _function_named(decorated, fn_name) + +def fails_on(*dbs): + """Mark a test as expected to fail on one or more database implementations. + + Unlike ``unsupported``, tests marked as ``fails_on`` will be run + for the named databases. The test is expected to fail and the unit test + logic is inverted: if the test fails, a success is reported. If the test + succeeds, a failure is reported. + """ + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if config.db.name not in dbs: + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected on DB implementation " + "'%s': %s" % ( + fn_name, config.db.name, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' on DB implementation '%s'" % + (fn_name, config.db.name)) + return _function_named(maybe, fn_name) + return decorate + +def fails_on_everything_except(*dbs): + """Mark a test as expected to fail on most database implementations. + + Like ``fails_on``, except failure is the expected outcome on all + databases except those listed. + """ + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if config.db.name in dbs: + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected on DB implementation " + "'%s': %s" % ( + fn_name, config.db.name, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' on DB implementation '%s'" % + (fn_name, config.db.name)) + return _function_named(maybe, fn_name) + return decorate def unsupported(*dbs): - """Mark a test as unsupported by one or more database implementations""" - + """Mark a test as unsupported by one or more database implementations. + + 'unsupported' tests will be skipped unconditionally. Useful for feature + tests that cause deadlocks or other fatal problems. 
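A sketch of these skip/expected-failure decorators in use (the dialect names and test bodies are illustrative only):

    @unsupported('sqlite')
    def test_two_phase_transactions():
        pass

    @fails_on('mysql', 'mssql')
    def test_deferrable_constraints():
        pass

    @fails_on_everything_except('postgres')
    def test_server_side_cursors():
        pass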
+ """ + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): @@ -22,40 +148,204 @@ def unsupported(*dbs): return True else: return fn(*args, **kw) - try: - maybe.__name__ = fn_name - except: - pass - return maybe + return _function_named(maybe, fn_name) return decorate -def supported(*dbs): - """Mark a test as supported by one or more database implementations""" - +def exclude(db, op, spec): + """Mark a test as unsupported by specific database server versions. + + Stackable, both with other excludes and other decorators. Examples:: + + # Not supported by mydb versions less than 1, 0 + @exclude('mydb', '<', (1,0)) + # Other operators work too + @exclude('bigdb', '==', (9,0,9)) + @exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) + """ + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name in dbs: + if _is_excluded(db, op, spec): + print "'%s' unsupported on DB %s version '%s'" % ( + fn_name, config.db.name, _server_version()) + return True + else: return fn(*args, **kw) + return _function_named(maybe, fn_name) + return decorate + +def _is_excluded(db, op, spec): + """Return True if the configured db matches an exclusion specification. + + db: + A dialect name + op: + An operator or stringified operator, such as '==' + spec: + A value that will be compared to the dialect's server_version_info + using the supplied operator. + + Examples:: + # Not supported by mydb versions less than 1, 0 + _is_excluded('mydb', '<', (1,0)) + # Other operators work too + _is_excluded('bigdb', '==', (9,0,9)) + _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) + """ + + if config.db.name != db: + return False + + version = _server_version() + + oper = hasattr(op, '__call__') and op or _ops[op] + return oper(version, spec) + +def _server_version(bind=None): + """Return a server_version_info tuple.""" + + if bind is None: + bind = config.db + return bind.dialect.server_version_info(bind.contextual_connect()) + +def emits_warning(*messages): + """Mark a test as emitting a warning. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + + # TODO: it would be nice to assert that a named warning was + # emitted. should work with some monkeypatching of warnings, + # and may work on non-CPython if they keep to the spirit of + # warnings.showwarning's docstring. + # - update: jython looks ok, it uses cpython's module + def decorate(fn): + def safe(*args, **kw): + global sa_exceptions + if sa_exceptions is None: + import sqlalchemy.exceptions as sa_exceptions + + if not messages: + filters = [dict(action='ignore', + category=sa_exceptions.SAWarning)] else: - print "'%s' unsupported on DB implementation '%s'" % ( - fn_name, config.db.name) - return True - try: - maybe.__name__ = fn_name - except: - pass - return maybe + filters = [dict(action='ignore', + message=message, + category=sa_exceptions.SAWarning) + for message in messages ] + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return _function_named(safe, fn.__name__) + return decorate + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). 
+ + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + """ + + def decorate(fn): + def safe(*args, **kw): + global sa_exceptions + if sa_exceptions is None: + import sqlalchemy.exceptions as sa_exceptions + + if not messages: + filters = [dict(action='ignore', + category=sa_exceptions.SADeprecationWarning)] + else: + filters = [dict(action='ignore', + message=message, + category=sa_exceptions.SADeprecationWarning) + for message in + [ (m.startswith('//') and + ('Call to deprecated function ' + m[2:]) or m) + for m in messages] ] + + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return _function_named(safe, fn.__name__) return decorate +def resetwarnings(): + """Reset warning behavior to testing defaults.""" + + global sa_exceptions + if sa_exceptions is None: + import sqlalchemy.exceptions as sa_exceptions + + warnings.resetwarnings() + warnings.filterwarnings('error', category=sa_exceptions.SADeprecationWarning) + warnings.filterwarnings('error', category=sa_exceptions.SAWarning) + + if sys.version_info < (2, 4): + warnings.filterwarnings('ignore', category=FutureWarning) + + +def against(*queries): + """Boolean predicate, compares to testing database configuration. + + Given one or more dialect names, returns True if one is the configured + database engine. + + Also supports comparison to database version when provided with one or + more 3-tuples of dialect name, operator, and version specification:: + + testing.against('mysql', 'postgres') + testing.against(('mysql', '>=', (5, 0, 0)) + """ + + for query in queries: + if isinstance(query, basestring): + if config.db.name == query: + return True + else: + name, op, spec = query + if config.db.name != name: + continue + + have = config.db.dialect.server_version_info( + config.db.contextual_connect()) + + oper = hasattr(op, '__call__') and op or _ops[op] + if oper(have, spec): + return True + return False + +def rowset(results): + """Converts the results of sql execution into a plain set of column tuples. + + Useful for asserting the results of an unordered query. 
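A sketch combining against() and rowset() inside a test body ('users_table' is a placeholder; the expected rows echo the standard user fixture):

    def test_user_rows():
        if against('mysql', ('postgres', '>=', (8, 3))):
            expected = set([(7, u'jack'), (8, u'ed'), (9, u'fred')])
            assert rowset(users_table.select().execute()) == expected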
+ """ + + return set([tuple(row) for row in results]) + + class TestData(object): """Tracks SQL expressions as they are executed via an instrumented ExecutionContext.""" - + def __init__(self): self.set_assert_list(None, None) self.sql_count = 0 self.buffer = None - + def set_assert_list(self, unittest, list): self.unittest = unittest self.assert_list = list @@ -68,18 +358,25 @@ testdata = TestData() class ExecutionContextWrapper(object): """instruments the ExecutionContext created by the Engine so that SQL expressions can be tracked.""" - + def __init__(self, ctx): + global sql + if sql is None: + from sqlalchemy import sql + self.__dict__['ctx'] = ctx def __getattr__(self, key): return getattr(self.ctx, key) def __setattr__(self, key, value): setattr(self.ctx, key, value) - + + trailing_underscore_pattern = re.compile(r'(\W:[\w_#]+)_\b',re.MULTILINE) def post_execution(self): ctx = self.ctx statement = unicode(ctx.compiled) statement = re.sub(r'\n', '', ctx.statement) + if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'): + statement = statement[:-25] if testdata.buffer is not None: testdata.buffer.write(statement + "\n") @@ -90,9 +387,9 @@ class ExecutionContextWrapper(object): item = testdata.assert_list.pop() else: # asserting a dictionary of statements->parameters - # this is to specify query assertions where the queries can be in + # this is to specify query assertions where the queries can be in # multiple orderings - if not item.has_key('_converted'): + if '_converted' not in item: for key in item.keys(): ckey = self.convert_statement(key) item[ckey] = item[key] @@ -110,21 +407,26 @@ class ExecutionContextWrapper(object): (query, params) = item if callable(params): params = params(ctx) - if params is not None and isinstance(params, list) and len(params) == 1: - params = params[0] - - if isinstance(ctx.compiled_parameters, sql.ClauseParameters): - parameters = ctx.compiled_parameters.get_original_dict() - elif isinstance(ctx.compiled_parameters, list): - parameters = [p.get_original_dict() for p in ctx.compiled_parameters] - + if params is not None and not isinstance(params, list): + params = [params] + + parameters = ctx.compiled_parameters + query = self.convert_statement(query) - if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'): - statement = statement[:-25] - testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) + equivalent = ( (statement == query) + or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) ) + ) \ + and \ + ( (params is None) or (params == parameters) + or params == [dict([(k.strip('_'), v) + for (k, v) in p.items()]) + for p in parameters] + ) + testdata.unittest.assert_(equivalent, + "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) testdata.sql_count += 1 self.ctx.post_execution() - + def convert_statement(self, query): paramstyle = self.ctx.dialect.paramstyle if paramstyle == 'named': @@ -143,7 +445,17 @@ class ExecutionContextWrapper(object): query = re.sub(r':([\w_]+)', repl, query) return query -class PersistTest(unittest.TestCase): +class TestBase(unittest.TestCase): + # A sequence of dialect names to exclude from the test class. 
+ __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None def __init__(self, *args, **params): unittest.TestCase.__init__(self, *args, **params) @@ -157,11 +469,79 @@ class PersistTest(unittest.TestCase): def shortDescription(self): """overridden to not return docstrings""" return None + + def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs): + try: + callable_(*args, **kwargs) + assert False, "Callable did not raise expected exception" + except except_cls, e: + assert re.search(msg, str(e)), "Exception message did not match: '%s'" % str(e) + + if not hasattr(unittest.TestCase, 'assertTrue'): + assertTrue = unittest.TestCase.failUnless + if not hasattr(unittest.TestCase, 'assertFalse'): + assertFalse = unittest.TestCase.failIf + +class AssertsCompiledSQL(object): + def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None): + if dialect is None: + dialect = getattr(self, '__dialect__', None) + + if params is None: + keys = None + else: + keys = params.keys() + + c = clause.compile(column_keys=keys, dialect=dialect) + + print "\nSQL String:\n" + str(c) + repr(c.params) + + cc = re.sub(r'\n', '', str(c)) + + self.assertEquals(cc, result) + + if checkparams is not None: + self.assertEquals(c.construct_params(params), checkparams) + +class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table): + global sqltypes, schema + if sqltypes is None: + import sqlalchemy.types as sqltypes + if schema is None: + import sqlalchemy.schema as schema + base_mro = sqltypes.TypeEngine.__mro__ + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + self.assertEquals(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + self.assertEquals(c.primary_key, reflected_c.primary_key) + self.assertEquals(c.nullable, reflected_c.nullable) + assert len( + set(type(reflected_c.type).__mro__).difference(base_mro).intersection( + set(type(c.type).__mro__).difference(base_mro) + ) + ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + + if isinstance(c.type, sqltypes.String): + self.assertEquals(c.type.length, reflected_c.type.length) + + self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + if c.default: + assert isinstance(reflected_c.default, schema.PassiveDefault) + elif against(('mysql', '<', (5, 0))): + # ignore reflection of bogus db-generated PassiveDefault() + pass + elif not c.primary_key or not against('postgres'): + print repr(c) + assert reflected_c.default is None, reflected_c.default + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] -class AssertMixin(PersistTest): - """given a list-based structure of keys/properties which represent information within an object structure, and - a list of actual objects, asserts that the list of objects corresponds to the structure.""" +class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): result = list(result) print repr(result) @@ -173,7 +553,7 @@ class AssertMixin(PersistTest): "for class " + class_.__name__) for i in range(0, len(list)): 
self.assert_row(class_, result[i], list[i]) - + def assert_row(self, class_, rowobj, desc): self.assert_(rowobj.__class__ is class_, "item class is not " + repr(class_)) @@ -187,12 +567,62 @@ class AssertMixin(PersistTest): self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % ( key, getattr(rowobj, key), value)) - + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + global util + if util is None: + from sqlalchemy import util + + class frozendict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = set([frozendict(e) for e in expected]) + + for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found): + self.fail('Unexpected type "%s", expected "%s"' % ( + type(wrong).__name__, cls.__name__)) + + if len(found) != len(expected): + self.fail('Unexpected object count "%s", expected "%s"' % ( + len(found), len(expected))) + + NOVALUE = object() + def _compare_item(obj, spec): + for key, value in spec.iteritems(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1]) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + self.fail( + "Expected %s instance with attributes %s not found." % ( + cls.__name__, repr(expected_item))) + return True + def assert_sql(self, db, callable_, list, with_sequences=None): global testdata testdata = TestData() - if with_sequences is not None and (config.db.name == 'postgres' or - config.db.name == 'oracle'): + if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): testdata.set_assert_list(self, with_sequences) else: testdata.set_assert_list(self, list) @@ -204,12 +634,10 @@ class AssertMixin(PersistTest): def assert_sql_count(self, db, callable_, count): global testdata testdata = TestData() - try: - callable_() - finally: - self.assert_(testdata.sql_count == count, - "desired statement count %d does not match %d" % ( - count, testdata.sql_count)) + callable_() + self.assert_(testdata.sql_count == count, + "desired statement count %d does not match %d" % ( + count, testdata.sql_count)) def capture_sql(self, db, callable_): global testdata @@ -223,20 +651,34 @@ class AssertMixin(PersistTest): testdata.buffer = None _otest_metadata = None -class ORMTest(AssertMixin): +class ORMTest(TestBase, AssertsExecutionResults): keep_mappers = False keep_data = False + metadata = None def setUpAll(self): - global _otest_metadata - _otest_metadata = MetaData(config.db) + global MetaData, _otest_metadata + + if MetaData is None: + from sqlalchemy import MetaData + + if self.metadata is None: + _otest_metadata = MetaData(config.db) + else: + _otest_metadata = self.metadata + if self.metadata.bind is None: + _otest_metadata.bind = config.db self.define_tables(_otest_metadata) _otest_metadata.create_all() + self.setup_mappers() self.insert_data() def define_tables(self, _otest_metadata): raise NotImplementedError() - + + def setup_mappers(self): + pass + def insert_data(self): pass @@ -244,22 +686,37 @@ class ORMTest(AssertMixin): return _otest_metadata def tearDownAll(self): + global 
clear_mappers + if clear_mappers is None: + from sqlalchemy.orm import clear_mappers + clear_mappers() _otest_metadata.drop_all() def tearDown(self): + global Session + if Session is None: + from sqlalchemy.orm.session import Session + Session.close_all() + global clear_mappers + if clear_mappers is None: + from sqlalchemy.orm import clear_mappers + if not self.keep_mappers: clear_mappers() if not self.keep_data: for t in _otest_metadata.table_iterator(reverse=True): - t.delete().execute().close() + try: + t.delete().execute().close() + except Exception, e: + print "EXCEPTION DELETING...", e class TTestSuite(unittest.TestSuite): """A TestSuite with once per TestCase setUpAll() and tearDownAll()""" def __init__(self, tests=()): - if len(tests) >0 and isinstance(tests[0], PersistTest): + if len(tests) > 0 and isinstance(tests[0], TestBase): self._initTest = tests[0] else: self._initTest = None @@ -278,20 +735,48 @@ class TTestSuite(unittest.TestSuite): return self(result) def __call__(self, result): + init = getattr(self, '_initTest', None) + if init is not None: + if (hasattr(init, '__unsupported_on__') and + config.db.name in init.__unsupported_on__): + print "'%s' unsupported on DB implementation '%s'" % ( + init.__class__.__name__, config.db.name) + return True + if (getattr(init, '__only_on__', None) not in (None,config.db.name)): + print "'%s' unsupported on DB implementation '%s'" % ( + init.__class__.__name__, config.db.name) + return True + if (getattr(init, '__skip_if__', False)): + for c in getattr(init, '__skip_if__'): + if c(): + print "'%s' skipped by %s" % ( + init.__class__.__name__, c.__name__) + return True + for rule in getattr(init, '__excluded_on__', ()): + if _is_excluded(*rule): + print "'%s' unsupported on DB %s version %s" % ( + init.__class__.__name__, config.db.name, + _server_version()) + return True + try: + resetwarnings() + init.setUpAll() + except: + # skip tests if global setup fails + ex = self.__exc_info() + for test in self._tests: + result.addError(test, ex) + return False try: - if self._initTest is not None: - self._initTest.setUpAll() - except: - result.addError(self._initTest, self.__exc_info()) - pass - try: + resetwarnings() return self.do_run(result) finally: try: - if self._initTest is not None: - self._initTest.tearDownAll() + resetwarnings() + if init is not None: + init.tearDownAll() except: - result.addError(self._initTest, self.__exc_info()) + result.addError(init, self.__exc_info()) pass def __exc_info(self): @@ -305,30 +790,9 @@ class TTestSuite(unittest.TestSuite): return (exctype, excvalue, tb) return (exctype, excvalue, tb) +# monkeypatch unittest.TestLoader.suiteClass = TTestSuite -def _iter_covered_files(): - import sqlalchemy - for rec in os.walk(os.path.dirname(sqlalchemy.__file__)): - for x in rec[2]: - if x.endswith('.py'): - yield os.path.join(rec[0], x) - -def cover(callable_, file_=None): - from testlib import coverage - coverage_client = coverage.the_coverage - coverage_client.get_ready() - coverage_client.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]') - coverage_client.erase() - coverage_client.start() - try: - return callable_() - finally: - coverage_client.stop() - coverage_client.save() - coverage_client.report(list(_iter_covered_files()), - show_missing=False, ignore_errors=False, - file=file_) class DevNullWriter(object): def write(self, msg): @@ -352,7 +816,7 @@ def runTests(suite): def main(suite=None): if not suite: - if len(sys.argv[1:]): + if sys.argv[1:]: suite =unittest.TestLoader().loadTestsFromNames( 
sys.argv[1:], __import__('__main__')) else: diff --git a/test/zblog/alltests.py b/test/zblog/alltests.py index ed430ac7ee..53947daa14 100644 --- a/test/zblog/alltests.py +++ b/test/zblog/alltests.py @@ -1,4 +1,4 @@ -import testbase +import testenv; testenv.configure_for_tests() import unittest def suite(): @@ -15,4 +15,4 @@ def suite(): if __name__ == '__main__': - testbase.main(suite()) + testenv.main(suite()) diff --git a/test/zblog/blog.py b/test/zblog/blog.py index e234bbbc76..9e48a202f0 100644 --- a/test/zblog/blog.py +++ b/test/zblog/blog.py @@ -1,11 +1,12 @@ __all__ = ['Blog', 'Post', 'Topic', 'TopicAssociation', 'Comment'] import datetime +from testlib.compat import * class Blog(object): def __init__(self, owner=None): self.owner = owner - + class Post(object): topics = set def __init__(self, user=None, headline=None, summary=None): @@ -15,7 +16,7 @@ class Post(object): self.summary = summary self.comments = [] self.comment_count = 0 - + class Topic(object): def __init__(self, keyword=None, description=None): self.keyword = keyword @@ -26,11 +27,9 @@ class TopicAssociation(object): self.post = post self.topic = topic self.is_primary = is_primary - + class Comment(object): def __init__(self, subject=None, body=None): self.subject = subject self.datetime = datetime.datetime.today() self.body = body - - diff --git a/test/zblog/mappers.py b/test/zblog/mappers.py index 11eaf4fd04..0d789f3d0b 100644 --- a/test/zblog/mappers.py +++ b/test/zblog/mappers.py @@ -8,9 +8,9 @@ from sqlalchemy.orm import * import sqlalchemy.util as util def zblog_mappers(): - # User mapper. Here, we redefine the names of some of the columns - # to different property names. normally the table columns are all - # sucked in automatically. + # User mapper. Here, we redefine the names of some of the columns to + # different property names. normally the table columns are all sucked in + # automatically. mapper(user.User, tables.users, properties={ 'id':tables.users.c.user_id, 'name':tables.users.c.user_name, @@ -18,30 +18,33 @@ def zblog_mappers(): 'crypt_password':tables.users.c.password, }) - # blog mapper. this contains a reference to the user mapper, - # and also installs a "backreference" on that relationship to handle it - # in both ways. this will also attach a 'blogs' property to the user mapper. + # blog mapper. this contains a reference to the user mapper, and also + # installs a "backreference" on that relationship to handle it in both + # ways. this will also attach a 'blogs' property to the user mapper. mapper(Blog, tables.blogs, properties={ 'id':tables.blogs.c.blog_id, - 'owner':relation(user.User, lazy=False, backref=backref('blogs', cascade="all, delete-orphan")), + 'owner':relation(user.User, lazy=False, + backref=backref('blogs', cascade="all, delete-orphan")), }) # topic mapper. map all topic columns to the Topic class. mapper(Topic, tables.topics) - - # TopicAssocation mapper. This is an "association" object, which is similar to - # a many-to-many relationship except extra data is associated with each pair - # of related data. because the topic_xref table doesnt have a primary key, - # the "primary key" columns of a TopicAssociation are defined manually here. - mapper(TopicAssociation,tables.topic_xref, - primary_key=[tables.topic_xref.c.post_id, tables.topic_xref.c.topic_id], + + # TopicAssocation mapper. This is an "association" object, which is + # similar to a many-to-many relationship except extra data is associated + # with each pair of related data. 
because the topic_xref table doesnt + # have a primary key, the "primary key" columns of a TopicAssociation are + # defined manually here. + mapper(TopicAssociation,tables.topic_xref, + primary_key=[tables.topic_xref.c.post_id, + tables.topic_xref.c.topic_id], properties={ 'topic':relation(Topic, lazy=False), }) - # Post mapper, these are posts within a blog. - # since we want the count of comments for each post, create a select that will get the posts - # and count the comments in one query. + # Post mapper, these are posts within a blog. + # since we want the count of comments for each post, create a select that + # will get the posts and count the comments in one query. posts_with_ccount = select( [c for c in tables.posts.c if c.key != 'body'] + [ func.count(tables.comments.c.comment_id).label('comment_count') @@ -54,41 +57,60 @@ def zblog_mappers(): ] ) .alias('postswcount') - # then create a Post mapper on that query. - # we have the body as "deferred" so that it loads only when needed, - # the user as a Lazy load, since the lazy load will run only once per user and - # its usually only one user's posts is needed per page, - # the owning blog is a lazy load since its also probably loaded into the identity map - # already, and topics is an eager load since that query has to be done per post in any - # case. + # then create a Post mapper on that query. + # we have the body as "deferred" so that it loads only when needed, the + # user as a Lazy load, since the lazy load will run only once per user and + # its usually only one user's posts is needed per page, the owning blog is + # a lazy load since its also probably loaded into the identity map + # already, and topics is an eager load since that query has to be done per + # post in any case. mapper(Post, posts_with_ccount, properties={ 'id':posts_with_ccount.c.post_id, 'body':deferred(tables.posts.c.body), - 'user':relation(user.User, lazy=True, backref=backref('posts', cascade="all, delete-orphan")), - 'blog':relation(Blog, lazy=True, backref=backref('posts', cascade="all, delete-orphan")), - 'topics':relation(TopicAssociation, lazy=False, private=True, association=Topic, backref='post') + 'user':relation(user.User, lazy=True, + backref=backref('posts', cascade="all, delete-orphan")), + 'blog':relation(Blog, lazy=True, + backref=backref('posts', cascade="all, delete-orphan")), + 'topics':relation(TopicAssociation, lazy=False, + cascade="all, delete-orphan", + backref='post') }, order_by=[desc(posts_with_ccount.c.datetime)]) - # comment mapper. This mapper is handling a hierarchical relationship on itself, and contains - # a lazy reference both to its parent comment and its list of child comments. + # comment mapper. This mapper is handling a hierarchical relationship on + # itself, and contains a lazy reference both to its parent comment and its + # list of child comments. 
mapper(Comment, tables.comments, properties={ 'id':tables.comments.c.comment_id, - 'post':relation(Post, lazy=True, backref=backref('comments', cascade="all, delete-orphan")), - 'user':relation(user.User, lazy=False, backref=backref('comments', cascade="all, delete-orphan")), - 'parent':relation(Comment, primaryjoin=tables.comments.c.parent_comment_id==tables.comments.c.comment_id, foreignkey=tables.comments.c.comment_id, lazy=True, uselist=False), - 'replies':relation(Comment,primaryjoin=tables.comments.c.parent_comment_id==tables.comments.c.comment_id, lazy=True, uselist=True, cascade="all"), + 'post':relation(Post, lazy=True, + backref=backref('comments', + cascade="all, delete-orphan")), + 'user':relation(user.User, lazy=False, + backref=backref('comments', + cascade="all, delete-orphan")), + 'parent':relation(Comment, + primaryjoin=(tables.comments.c.parent_comment_id == + tables.comments.c.comment_id), + foreign_keys=[tables.comments.c.comment_id], + lazy=True, uselist=False), + 'replies':relation(Comment, + primaryjoin=(tables.comments.c.parent_comment_id == + tables.comments.c.comment_id), + lazy=True, uselist=True, cascade="all"), }) -# we define one special find-by for the comments of a post, which is going to make its own "noload" -# mapper and organize the comments into their correct hierarchy in one pass. hierarchical -# data normally needs to be loaded by separate queries for each set of children, unless you -# use a proprietary extension like CONNECT BY. +# we define one special find-by for the comments of a post, which is going to +# make its own "noload" mapper and organize the comments into their correct +# hierarchy in one pass. hierarchical data normally needs to be loaded by +# separate queries for each set of children, unless you use a proprietary +# extension like CONNECT BY. def find_by_post(post): - """returns a hierarchical collection of comments based on a given criterion. - uses a mapper that does not lazy load replies or parents, and instead + """returns a hierarchical collection of comments based on a given criterion. + + Uses a mapper that does not lazy load replies or parents, and instead organizes comments into a hierarchical tree when the result is produced. 
""" + q = session().query(Comment).options(noload('replies'), noload('parent')) comments = q.select_by(post_id=post.id) result = [] @@ -112,4 +134,3 @@ def start_session(): def session(): return trans.session - diff --git a/test/zblog/tables.py b/test/zblog/tables.py index 5b4054a195..408762e451 100644 --- a/test/zblog/tables.py +++ b/test/zblog/tables.py @@ -6,50 +6,49 @@ from testlib import * metadata = MetaData() -users = Table('users', metadata, - Column('user_id', Integer, primary_key=True), +users = Table('users', metadata, + Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True), Column('user_name', String(30), nullable=False), Column('fullname', String(100), nullable=False), Column('password', String(40), nullable=False), Column('groupname', String(20), nullable=False), ) -blogs = Table('blogs', metadata, - Column('blog_id', Integer, primary_key=True), +blogs = Table('blogs', metadata, + Column('blog_id', Integer, Sequence('blog_id_seq', optional=True), primary_key=True), Column('owner_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('name', String(100), nullable=False), Column('description', String(500)) ) - + posts = Table('posts', metadata, - Column('post_id', Integer, primary_key=True), + Column('post_id', Integer, Sequence('post_id_seq', optional=True), primary_key=True), Column('blog_id', Integer, ForeignKey('blogs.blog_id'), nullable=False), Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('datetime', DateTime, nullable=False), Column('headline', String(500)), - Column('summary', String), - Column('body', String), + Column('summary', Text), + Column('body', Text), ) - + topics = Table('topics', metadata, Column('topic_id', Integer, primary_key=True), Column('keyword', String(50), nullable=False), Column('description', String(500)) ) - -topic_xref = Table('topic_post_xref', metadata, + +topic_xref = Table('topic_post_xref', metadata, Column('topic_id', Integer, ForeignKey('topics.topic_id'), nullable=False), Column('is_primary', Boolean, nullable=False), Column('post_id', Integer, ForeignKey('posts.post_id'), nullable=False) ) -comments = Table('comments', metadata, +comments = Table('comments', metadata, Column('comment_id', Integer, primary_key=True), Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('post_id', Integer, ForeignKey('posts.post_id'), nullable=False), Column('datetime', DateTime, nullable=False), Column('parent_comment_id', Integer, ForeignKey('comments.comment_id')), Column('subject', String(500)), - Column('body', String), + Column('body', Text), ) - diff --git a/test/zblog/tests.py b/test/zblog/tests.py index ad6876937d..4f77d350d3 100644 --- a/test/zblog/tests.py +++ b/test/zblog/tests.py @@ -1,5 +1,4 @@ -import testbase - +import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * @@ -8,14 +7,14 @@ from zblog.user import * from zblog.blog import * -class ZBlogTest(AssertMixin): +class ZBlogTest(TestBase, AssertsExecutionResults): def create_tables(self): - tables.metadata.drop_all(bind=testbase.db) - tables.metadata.create_all(bind=testbase.db) + tables.metadata.drop_all(bind=testing.db) + tables.metadata.create_all(bind=testing.db) def drop_tables(self): - tables.metadata.drop_all(bind=testbase.db) - + tables.metadata.drop_all(bind=testing.db) + def setUpAll(self): self.create_tables() def tearDownAll(self): @@ -31,7 +30,7 @@ class SavePostTest(ZBlogTest): super(SavePostTest, 
self).setUpAll() mappers.zblog_mappers() global blog_id, user_id - s = create_session(bind=testbase.db) + s = create_session(bind=testing.db) user = User('zbloguser', "Zblog User", "hello", group=administrator) blog = Blog(owner=user) blog.name = "this is a blog" @@ -45,12 +44,12 @@ class SavePostTest(ZBlogTest): def tearDownAll(self): clear_mappers() super(SavePostTest, self).tearDownAll() - + def testattach(self): """test that a transient/pending instance has proper bi-directional behavior. - + this requires that lazy loaders do not fire off for a transient/pending instance.""" - s = create_session(bind=testbase.db) + s = create_session(bind=testing.db) s.begin() try: @@ -62,12 +61,12 @@ class SavePostTest(ZBlogTest): assert post in blog.posts finally: s.rollback() - + def testoptimisticorphans(self): - """test that instances in the session with un-loaded parents will not + """test that instances in the session with un-loaded parents will not get marked as "orphans" and then deleted """ - s = create_session(bind=testbase.db) - + s = create_session(bind=testing.db) + s.begin() try: blog = s.query(Blog).get(blog_id) @@ -79,6 +78,7 @@ class SavePostTest(ZBlogTest): s.flush() s.clear() + user = s.query(User).get(user_id) blog = s.query(Blog).get(blog_id) post = blog.posts[0] comment = Comment(subject="some subject", body="some body") @@ -86,14 +86,12 @@ class SavePostTest(ZBlogTest): comment.user = user s.flush() s.clear() - + assert s.query(Post).get(post.id) is not None - + finally: s.rollback() - - -if __name__ == "__main__": - testbase.main() - + +if __name__ == "__main__": + testenv.main() diff --git a/test/zblog/user.py b/test/zblog/user.py index 3e77fa8427..973413d922 100644 --- a/test/zblog/user.py +++ b/test/zblog/user.py @@ -9,9 +9,10 @@ groups = [user, administrator] def cryptpw(password, salt=None): if salt is None: - salt = string.join([chr(random.randint(ord('a'), ord('z'))), chr(random.randint(ord('a'), ord('z')))],'') + salt = string.join([chr(random.randint(ord('a'), ord('z'))), + chr(random.randint(ord('a'), ord('z')))],'') return sha(password + salt).hexdigest() - + def checkpw(password, dbpw): return cryptpw(password, dbpw[:2]) == dbpw @@ -32,4 +33,4 @@ class User(object): password = property(lambda s: None, _set_password) def checkpw(self, password): - return checkpw(password, self.crypt_password) \ No newline at end of file + return checkpw(password, self.crypt_password)
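A small sketch of the salted-hash helper in zblog/user.py: with an explicit two-character salt the digest is deterministic:

    from zblog.user import cryptpw

    digest = cryptpw('secret', salt='ab')    # sha('secret' + 'ab').hexdigest()
    assert digest == cryptpw('secret', salt='ab')
    assert digest != cryptpw('secret', salt='zz')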